# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes for different algorithms of reduction and broadcasting."""

import collections
import copy
import multiprocessing.dummy
import multiprocessing.pool
import threading

import six

from tensorflow.python.client import device_lib
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import kernels
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls


def check_destinations(destinations):
  """Checks whether `destinations` is not empty.

  Args:
    destinations: a `DistributedValues`, variable, or string object.

  Returns:
    Boolean which is True if `destinations` is not empty.
  """
  # Calling bool() on a ResourceVariable is not allowed.
  if isinstance(destinations,
                (resource_variable_ops.BaseResourceVariable, ops.Tensor)):
    return bool(destinations.device)
  return bool(destinations)


def validate_destinations(destinations):
  """Validates that `destinations` is one of the expected types."""
  if not isinstance(
      destinations,
      (value_lib.DistributedValues, ops.Tensor, indexed_slices.IndexedSlices,
       ps_values.AggregatingVariable, six.string_types,
       tpu_values.TPUMirroredVariable
      )) and not resource_variable_ops.is_resource_variable(destinations):
    raise ValueError("destinations must be one of a `DistributedValues` object,"
                     " a tf.Variable object, or a device string.")

  if not check_destinations(destinations):
    raise ValueError("destinations can not be empty")


def reduce_non_distributed_value(reduce_op,
                                 value,
                                 destinations,
                                 num_replicas_in_graph,
                                 canonicalize_devices=True):
  """Reduce a non-DistributedValue `value` to `destinations`."""
  if isinstance(value, value_lib.DistributedValues):
    raise ValueError("You are passing a `DistributedValues` to "
                     "`reduce_non_distributed_value`, which is not allowed.")

  # If the same value is present on all replicas then the PerReplica value will
  # be a single value. We also handle the case when `value` is a single value
  # and equal to 0.
  # TODO(b/138823479): handle the tensor value properly.
  if not tensor_util.is_tf_type(value) and value == 0:
    return 0
  # If there is only a single value and the reduce op is MEAN,
  # that value should be on all destinations.
  if reduce_op == reduce_util.ReduceOp.MEAN:
    return value
  elif num_replicas_in_graph != 1:
    # We do not support a reduce op of SUM if the value is the same across
    # all replicas. This is called as part of assign functions for
    # MirroredVariables, and summing up identical values across replicas is not
    # clearly defined.
    raise ValueError("A non-DistributedValues value %s cannot be reduced with "
                     "the given reduce op %s." % (value, reduce_op))
  else:
    validate_destinations(destinations)
    return simple_broadcast(
        value, destinations, canonicalize_devices=canonicalize_devices)


def _make_tensor_into_per_replica(input_tensor):
  """Converts a single tensor into a PerReplica object."""
  if isinstance(input_tensor, value_lib.DistributedValues):
    return input_tensor

  # If input is not a Tensor, convert it to a Tensor first.
  if not tensor_util.is_tensor(input_tensor):
    input_tensor = ops.convert_to_tensor(input_tensor)

  if hasattr(input_tensor, "device"):
    return value_lib.PerReplica((input_tensor,))

  raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object "
                   "because it doesn't have device set.")


def _normalize_value_destination_pairs(value_destination_pairs):
  """Converts each tensor into a PerReplica object in the input list."""
  result = []

  value_destination_pairs = list(value_destination_pairs)

  if not isinstance(value_destination_pairs, (list, tuple)):
    raise ValueError("`value_destination_pairs` should be a list or tuple")
  for pair in value_destination_pairs:
    if not isinstance(pair, tuple):
      raise ValueError(
          "Each element of `value_destination_pairs` should be a tuple.")
    if len(pair) != 2:
      raise ValueError("Each element of `value_destination_pairs` should be a "
                       "tuple of size 2.")

    per_replica = _make_tensor_into_per_replica(pair[0])
    result.append((per_replica, pair[1]))
  return result


def _validate_value_destination_pairs(value_destination_pairs):
  """Validates that `value_destination_pairs` is valid."""
  # TODO(yuefengz): raise exceptions instead of returning False.
  if not value_destination_pairs: return False
  if not isinstance(value_destination_pairs, (list, tuple)): return False
  if not all(isinstance(pair, tuple) for pair in value_destination_pairs):
    return False
  if not all(isinstance(v[0], value_lib.PerReplica)
             for v in value_destination_pairs):
    return False
  return True


# TODO(yuefengz): consider calling this function in the caller of
# CrossDeviceOps.
def get_devices_from(destinations, canonicalize_devices=True):
  if isinstance(destinations, value_lib.DistributedValues):
    return destinations._devices  # pylint: disable=protected-access
  if canonicalize_devices:
    if isinstance(destinations, six.string_types):
      return (device_util.resolve(destinations),)
    return (device_util.resolve(destinations.device),)

  # Let placer canonicalize and resolve destination devices.
  if isinstance(destinations, six.string_types):
    return (device_util.canonicalize_without_job_and_task(destinations),)
  return (device_util.canonicalize_without_job_and_task(destinations.device),)


def _devices_match(left, right, canonicalize_devices=True):
  return left is right or set(get_devices_from(
      left, canonicalize_devices)) == set(
          get_devices_from(right, canonicalize_devices))


def _all_devices_match(value_destination_pairs, canonicalize_devices=True):
  if not all(
      _devices_match(v, d, canonicalize_devices)
      for v, d in value_destination_pairs):
    return False
  if not all(
      _devices_match(v, value_destination_pairs[0][0], canonicalize_devices)
      for v, _ in value_destination_pairs[1:]):
    return False
  return True


def simple_broadcast(value,
                     destinations,
                     always_mirrored=False,
                     canonicalize_devices=True):
  """Broadcast `value` to `destinations` using simple copies."""
  devices = get_devices_from(destinations, canonicalize_devices)
  if len(devices) == 1 and not always_mirrored:
    return cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        value, devices[0])
  else:
    value_updates = []
    for d in devices:
      value_updates.append(
          cross_device_utils.copy_tensor_or_indexed_slices_to_device(value, d))
    return distribute_utils.regroup(value_updates,
                                    wrap_class=value_lib.Mirrored)


def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
                   reduce_op):
  """Reduces the value by accumulation_fn and reduce_op."""
  all_values = per_replica_value.values
  if not all_values:
    raise ValueError("`per_replica_value` must be non-empty")
  count = len(all_values)

  with ops.device(reduce_to_device):
    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
      reduced = cross_device_utils.aggregate_tensors_or_indexed_slices(
          all_values, accumulation_fn)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        reduced = cross_device_utils.divide_by_n_tensors_or_indexed_slices(
            reduced, count)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise ValueError("`reduce_op` must be Reduce.SUM or Reduce.MEAN.")
  return reduced


def _simple_gather(per_replica_value, reduce_to_device, axis):
  """Concatenate all values in the DistributedValues input and return."""
  all_values = per_replica_value.values
  if not all_values:
    raise ValueError("`per_replica_value` must be non-empty")

  with ops.device(reduce_to_device):
    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
      gathered = array_ops.concat(all_values, axis)
  return gathered


@tf_export("distribute.CrossDeviceOps")
class CrossDeviceOps(object):
  """Base class for cross-device reduction and broadcasting algorithms.

  The main purpose of this class is to be passed to
  `tf.distribute.MirroredStrategy` in order to choose among different cross
  device communication implementations. Prefer using the methods of
  `tf.distribute.Strategy` instead of the ones of this class.

  Implementations:
  * `tf.distribute.ReductionToOneDevice`
  * `tf.distribute.NcclAllReduce`
  * `tf.distribute.HierarchicalCopyAllReduce`
  """

  def __init__(self):
    self._canonicalize_devices = True

  @property
  def _num_between_graph_workers(self):
    # Returns 1 by default; the value may be overridden by subclasses.
    return 1

  def reduce(self, reduce_op, per_replica_value, destinations, options=None):
    """Reduce `per_replica_value` to `destinations`.

    See `tf.distribute.StrategyExtended.reduce_to`. This can only be called in
    the cross-replica context.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
        like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to reduce to. To perform an all-reduce, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is reduced
        to the devices of that variable, and this method doesn't update the
        variable.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    if options is None:
      options = collective_util.Options()

    per_replica_value = _make_tensor_into_per_replica(per_replica_value)

    validate_destinations(destinations)

    # Shortcut if `per_replica_value` only contains one value.
    if self._num_between_graph_workers == 1 and len(
        per_replica_value.values) == 1 and _devices_match(
            per_replica_value, destinations, self._canonicalize_devices):
      with ops.device(per_replica_value.values[0].device):
        v = array_ops.identity(per_replica_value.values[0])
      return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)

    if options is None:
      options = collective_util.Options()
    return self.reduce_implementation(reduce_op, per_replica_value,
                                      destinations, options)

  def _gather(self, per_replica_value, destinations, axis, options=None):
    """Gather `per_replica_value` to `destinations`.

    Args:
      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
        like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to gather to. To perform an all-gather, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is gathered
        to the devices of that variable, and this method doesn't update the
        variable.
      axis: specifies the dimension to gather along within each replica's
        tensor.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
344 """ 345 if isinstance(per_replica_value, indexed_slices.IndexedSlices): 346 raise NotImplementedError("gather/all_gather does not support " 347 "IndexedSlices") 348 if options is None: 349 options = collective_util.Options() 350 351 per_replica_value = _make_tensor_into_per_replica(per_replica_value) 352 353 validate_destinations(destinations) 354 355 # Shortcut if `per_replica_value` only contains one value. 356 if self._num_between_graph_workers == 1 and len( 357 per_replica_value.values) == 1 and _devices_match( 358 per_replica_value, destinations, self._canonicalize_devices): 359 with ops.device(per_replica_value.values[0].device): 360 v = array_ops.identity(per_replica_value.values[0]) 361 return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored) 362 363 return self._gather_implementation(per_replica_value, destinations, axis, 364 options) 365 366 def _gather_implementation(self, per_replica_value, destinations, axis, 367 options): 368 """Implementation of `gather` method of `tf.distribute.CrossDeviceOps`. 369 370 Overriding this method is useful for subclass implementers. 371 372 Args: 373 per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` 374 like object. 375 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a 376 `tf.Tensor` alike object, or a device string. It specifies the devices 377 to gather to. To perform an all-gather, pass the same to `value` and 378 `destinations`. Note that if it's a `tf.Variable`, the value is gathered 379 to the devices of that variable, this method doesn't update the 380 variable. 381 axis: specifies the dimension to gather along within each replica's 382 tensor. 383 options: a `tf.distribute.experimental.CommunicationOptions`. See 384 `tf.distribute.experimental.CommunicationOptions` for details. 385 386 Returns: 387 A `tf.Tensor` or `tf.distribute.DistributedValues`. 388 389 Raises: 390 ValueError: if per_replica_value can't be converted to a 391 `tf.distribute.DistributedValues` or if destinations is not a string, 392 `tf.Variable` or `tf.distribute.DistributedValues`. 393 """ 394 raise NotImplementedError( 395 "_gather method must be implemented in descendants.") 396 397 def batch_reduce(self, reduce_op, value_destination_pairs, options=None): 398 """Reduce values to destinations in batches. 399 400 See `tf.distribute.StrategyExtended.batch_reduce_to`. This can only be 401 called in the cross-replica context. 402 403 Args: 404 reduce_op: a `tf.distribute.ReduceOp` specifying how values should be 405 combined. 406 value_destination_pairs: a sequence of (value, destinations) pairs. See 407 `tf.distribute.CrossDeviceOps.reduce` for descriptions. 408 options: a `tf.distribute.experimental.CommunicationOptions`. See 409 `tf.distribute.experimental.CommunicationOptions` for details. 410 411 Returns: 412 A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair 413 in `value_destination_pairs`. 414 415 Raises: 416 ValueError: if `value_destination_pairs` is not an iterable of 417 tuples of `tf.distribute.DistributedValues` and destinations. 418 """ 419 if options is None: 420 options = collective_util.Options() 421 # TODO(yuefengz): if destinations are different, split into several 422 # `_batch_reduce` invocations. 423 if not _validate_value_destination_pairs(value_destination_pairs): 424 # If the first element of each pair is a tensor, we try to turn it into a 425 # PerReplica object. 
      value_destination_pairs = _normalize_value_destination_pairs(
          value_destination_pairs)

    for _, d in value_destination_pairs:
      validate_destinations(d)

    # Shortcut if all PerReplica objects only contain one value.
    if self._num_between_graph_workers == 1 and _all_devices_match(
        value_destination_pairs, self._canonicalize_devices) and len(
            value_destination_pairs[0][0].values) == 1:
      return [
          distribute_utils.regroup(v.values, wrap_class=value_lib.Mirrored)
          for v, _ in value_destination_pairs
      ]

    if options is None:
      options = collective_util.Options()
    return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
                                            options)

  def broadcast(self, tensor, destinations):
    """Broadcast `tensor` to `destinations`.

    This can only be called in the cross-replica context.

    Args:
      tensor: a `tf.Tensor` like object. The value to broadcast.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to broadcast to. Note that if it's a `tf.Variable`, the value is
        broadcasted to the devices of that variable, and this method doesn't
        update the variable.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.
    """
    validate_destinations(destinations)
    return self.broadcast_implementation(tensor, destinations)

  @doc_controls.for_subclass_implementers
  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    """Implementation of `reduce`.

    Overriding this method is useful for subclass implementers.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
        like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to reduce to. To perform an all-reduce, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is reduced
        to the devices of that variable, and this method doesn't update the
        variable.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    raise NotImplementedError(
        "_reduce method must be implemented in descendants.")

  @doc_controls.for_subclass_implementers
  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    """Implementation of `batch_reduce`.

    Overriding this method is useful for subclass implementers.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      value_destination_pairs: a sequence of (value, destinations) pairs. See
        `reduce` for descriptions.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
      in `value_destination_pairs`.

    Raises:
      ValueError: if `value_destination_pairs` is not an iterable of
        tuples of `tf.distribute.DistributedValues` and destinations.
    """
    raise NotImplementedError(
        "batch_reduce_implementation method must be implemented in descendants."
    )

  @doc_controls.for_subclass_implementers
  def broadcast_implementation(self, tensor, destinations):
    """Implementation of `broadcast`.

    Args:
      tensor: a `tf.Tensor` like object. The value to broadcast.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to broadcast to. Note that if it's a `tf.Variable`, the value is
        broadcasted to the devices of that variable, and this method doesn't
        update the variable.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.
    """
    return simple_broadcast(
        tensor,
        destinations,
        always_mirrored=True,
        canonicalize_devices=self._canonicalize_devices)

  # ========================== Collective APIs ================================
  #
  # Unlike `reduce`, `batch_reduce` and `broadcast`, which must be called in
  # cross-replica context, collective APIs are to be called in replica context.

  def _all_reduce(self, reduce_op, value, replica_id, options):
    """All-reduce the `value` across all replicas so that all get the result.

    `value` can be a nested structure of tensors or `IndexedSlices`. The
    implementation should generally batch the all-reduces when possible.
    `options` can be set to hint the batching behavior.

    This API must be called in a replica context.

    Args:
      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
        be combined.
      value: Value to be reduced. A tensor or a nested structure of tensors or
        `IndexedSlices`.
      replica_id: An integer indicating the id of the replica under which this
        all_reduce is called. This is the local replica id that ranges
        from 0 to len(local_devices) - 1.
      options: A `tf.distribute.experimental.CommunicationOptions`.

    Returns:
      A tensor/IndexedSlices or a nested structure of tensors/IndexedSlices with
      the reduced values. The structure is the same as `value`.
    """
    raise NotImplementedError("_all_reduce must be implemented in descendants.")


@tf_export("distribute.ReductionToOneDevice")
class ReductionToOneDevice(CrossDeviceOps):
  """A CrossDeviceOps implementation that copies values to one device to reduce.

  This implementation always copies values to one device to reduce them, then
  broadcasts the reduced values to the destinations. It doesn't support
  efficient batching.

  Here is how you can use `ReductionToOneDevice` in
  `tf.distribute.MirroredStrategy`:

  ```
  strategy = tf.distribute.MirroredStrategy(
    cross_device_ops=tf.distribute.ReductionToOneDevice())
  ```
  """

  def __init__(self, reduce_to_device=None, accumulation_fn=None):
    """Initializes with a device to reduce to and a way to accumulate.

    Args:
      reduce_to_device: the intermediate device to reduce to. If None, reduce
        to the first device in `destinations` of the `reduce` method.
      accumulation_fn: a function that does accumulation. If None,
        `tf.math.add_n` is used.
603 """ 604 self.reduce_to_device = reduce_to_device 605 self.accumulation_fn = accumulation_fn or math_ops.add_n 606 super(ReductionToOneDevice, self).__init__() 607 608 def reduce_implementation(self, reduce_op, per_replica_value, destinations, 609 options): 610 del options # Unused. 611 if check_destinations(destinations): 612 devices = get_devices_from(destinations, self._canonicalize_devices) 613 else: 614 devices = get_devices_from(per_replica_value, self._canonicalize_devices) 615 reduce_to_device = self.reduce_to_device or devices[0] 616 logging.log_first_n( 617 logging.INFO, 618 "Reduce to %s then broadcast to %r." % (reduce_to_device, devices), 10) 619 reduced = _simple_reduce(per_replica_value, reduce_to_device, 620 self.accumulation_fn, reduce_op) 621 return self.broadcast(reduced, destinations) 622 623 def _gather_implementation(self, per_replica_value, destinations, axis, 624 options): 625 del options # Unused. 626 if check_destinations(destinations): 627 devices = get_devices_from(destinations, self._canonicalize_devices) 628 else: 629 devices = get_devices_from(per_replica_value, self._canonicalize_devices) 630 reduce_to_device = self.reduce_to_device or devices[0] 631 logging.log_first_n( 632 logging.INFO, 633 "Gather to %s then broadcast to %r." % (reduce_to_device, devices), 10) 634 gathered = _simple_gather(per_replica_value, reduce_to_device, axis) 635 return self.broadcast(gathered, destinations) 636 637 def batch_reduce_implementation(self, reduce_op, value_destination_pairs, 638 options): 639 return [ 640 self.reduce_implementation( 641 reduce_op, t, destinations=v, options=options) 642 for t, v in value_destination_pairs 643 ] 644 645 646def _group_value_by_device(per_replica_values): 647 """Group values into sublists by their devices. 648 649 This grouping is needed to call the all-reduce library because it expects a 650 list of the following form: 651 [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...], 652 [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...], 653 [(grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ...], 654 ... 655 ] 656 657 Args: 658 per_replica_values: a list of PerReplica objects. 659 660 Returns: 661 a list of lists, each sublist has components for its corresponding device of 662 PerReplica objects, paired with a None. 663 """ 664 destinations = per_replica_values[0]._devices # pylint: disable=protected-access 665 grouped = [[] for _ in range(len(destinations))] 666 for per_replica_value in per_replica_values: 667 # pylint: disable=protected-access 668 for i, v in enumerate(per_replica_value.values): 669 assert per_replica_value._devices == destinations 670 grouped[i].append((v, None)) 671 return grouped 672 673 674def _ungroup_and_make_mirrored(grouped_reduced, 675 destinations, 676 reduce_op, 677 num_between_graph_workers=1): 678 """Ungroup results from all-reduce and make Mirrored objects. 679 680 Each all-reduce result will be divided by the number of destinations before 681 Mirrored objects are created if reduce_op is "mean". 682 683 Args: 684 grouped_reduced: a list of lists, each sublist has components for each 685 device, paired with a None. It is the result from 686 cross_device_utils.aggregate_gradients_using*. 687 destinations: a value to colocate the result with. 688 reduce_op: Indicates how values will be aggregated. Accepted values 689 are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`. 
    num_between_graph_workers: number of workers in the between-graph
      replication.

  Returns:
    a list of Mirrored objects.
  """
  num_replicas = len(get_devices_from(destinations)) * num_between_graph_workers
  index = [[] for _ in range(len(grouped_reduced[0]))]
  for per_replica_reduced in grouped_reduced:
    for i, (v, _) in enumerate(per_replica_reduced):
      if reduce_op == reduce_util.ReduceOp.MEAN:
        with ops.device(v.device):
          index[i].append(v / num_replicas)
      else:
        index[i].append(v)
  return [distribute_utils.regroup(
      v, wrap_class=value_lib.Mirrored) for v in index]


class _ConcatAndSplitPacker(object):
  """Concatenate and split tensors for reduction."""

  def __init__(self, num_packs=1):
    """Initialize the _ConcatAndSplitPacker object.

    Args:
      num_packs: specifies the number of split packs that will be
        formed.

    Raises:
      ValueError: if num_packs is not greater than 0.
    """
    if num_packs <= 0:
      raise ValueError("num_packs must be greater than zero.")
    self.num_packs = num_packs

  def pack(self, grouped_grads_and_vars):
    """Pack tensors."""
    self.grouped_grads_and_vars = grouped_grads_and_vars
    self.all_device_shapes = []
    self.all_device_sizes = []

    device_grad_packs = []
    for device_grads_and_vars in grouped_grads_and_vars:
      with ops.colocate_with(device_grads_and_vars[0][0]):
        # Flatten all the grads.
        flat_grads = [
            array_ops.reshape(g, [-1]) for g, _ in device_grads_and_vars
        ]
        # Remember the original shape of all the grads.
        device_shapes = [array_ops.shape(g) for g, _ in device_grads_and_vars]
        # Remember the original sizes of all the grads.
        device_sizes = [array_ops.size(g) for g, _ in device_grads_and_vars]
        # Concat all the flat grads into a big flat tensor.
        concat_grads = array_ops.concat(flat_grads, 0)

        # Split the big tensor into num_splits packs. In cases where the
        # total size is not divisible by num_splits, the last pack gets
        # more elements.
        # TODO(zhengxq): it is also possible to optimize away all the concat
        # as well.
        num_splits = self.num_packs

        # The array_ops.size function will sometimes remove static shapes. So
        # if all gradient shapes are defined, we use another method to get the
        # total size.
        # TODO(yuefengz): move this logic to array_ops.size.
        if all(g.shape.is_fully_defined() for g, _ in device_grads_and_vars):
          total_grad_size = sum(
              [g.shape.num_elements() for g, _ in device_grads_and_vars])
        else:
          total_grad_size = array_ops.size(concat_grads)

        split_size = total_grad_size // num_splits
        split_size_last = total_grad_size - split_size * (num_splits - 1)
        split_sizes = [split_size] * (num_splits - 1) + [split_size_last]
        grad_packs = array_ops.split(concat_grads, split_sizes)

        # Ready to aggregate the repacked gradients, with fake variables.
        # TODO(zhengxq): It is hacky to have to use fake variables.
        #   We should remove the need for variables in
        #   aggregate_gradients_using*.
        device_grad_packs.append(zip(grad_packs, [None] * num_splits))
        self.all_device_shapes.append(device_shapes)
        self.all_device_sizes.append(device_sizes)

    return device_grad_packs

  def unpack(self, summed_device_grad_packs):
    """Reverse the pack."""
    aggregated_device_grads = []
    for (summed_device_grad_packs,
         device_grads_and_vars, device_shapes, device_sizes) in zip(
             summed_device_grad_packs, self.grouped_grads_and_vars,
             self.all_device_shapes, self.all_device_sizes):
      # pylint: enable=line-too-long
      # Reverse the packing operations in the previous steps. Form the
      # summed gradients back into their original shapes.
      with ops.colocate_with(summed_device_grad_packs[0][0]):
        # Form a list of the summed grad packs.
        device_grad_packs = [g for g, _ in summed_device_grad_packs]

        # Concat them back into a big flat tensor.
        device_grads_concat = array_ops.concat(device_grad_packs, 0)

        # Split the tensors back into their original sizes.
        grads_with_sizes = array_ops.split(device_grads_concat, device_sizes)

        # Reshape the tensors back into their original shapes.
        grads_with_shapes = [
            array_ops.reshape(grad, shape)
            for shape, grad in zip(device_shapes, grads_with_sizes)
        ]

        # Form the list with the original list of variables.
        summed_device_grads = [
            (g, v) for g, (_, v) in zip(grads_with_shapes,
                                        device_grads_and_vars)
        ]
        aggregated_device_grads.append(summed_device_grads)
    return aggregated_device_grads


def _pack_tensors(device_grads, num_packs=0):
  """Pack tensors if specified."""
  if num_packs > 0:
    tensor_packer = _ConcatAndSplitPacker(num_packs)
    device_grad_packs = tensor_packer.pack(device_grads)
  else:
    tensor_packer = None
    device_grad_packs = device_grads
  return device_grad_packs, tensor_packer


def _unpack_tensors(reduced, tensor_packer=None):
  """Unpack tensors if they are packed before all-reduce."""
  if tensor_packer:
    return tensor_packer.unpack(reduced)
  return reduced


class AllReduceCrossDeviceOps(CrossDeviceOps):
  """All-reduce implementation of CrossDeviceOps.

  It performs all-reduce when applicable using NCCL or hierarchical copy. For
  the batch API, tensors will be repacked or aggregated for more efficient
  cross-device transportation.

  For reduces that are not all-reduce, it falls back to
  `tf.distribute.ReductionToOneDevice`.
  """

  def __init__(self, all_reduce_alg="nccl", num_packs=1):
    """Initializes the object.

    Args:
      all_reduce_alg: the all-reduce algorithm to use, currently only "nccl" or
        "hierarchical_copy" are supported.
      num_packs: a non-negative integer. The number of packs to split values
        into. If zero, no packing will be done.
    """
    self._all_reduce_alg = all_reduce_alg
    self._num_packs = num_packs
    self._simple_cross_replica_ops = ReductionToOneDevice()
    super(AllReduceCrossDeviceOps, self).__init__()

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    del options  # Unused.
    # To use NCCL or all-reduce, source and destination devices should match,
    # and none of the devices should be CPU.
    if (_devices_match(per_replica_value, destinations) and
        not any("cpu" in d.lower() for d in get_devices_from(destinations))):
      return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
    else:
      return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value,
                                                   destinations)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    if _all_devices_match(value_destination_pairs):
      return self._batch_all_reduce(reduce_op,
                                    [v[0] for v in value_destination_pairs])
    else:
      return [
          self.reduce_implementation(reduce_op, value, dest, options)
          for value, dest in value_destination_pairs
      ]

  def _batch_all_reduce(self, reduce_op, per_replica_values):
    """All-reduce algorithm in a batch."""
    dense_values, dense_indices, sparse_values, sparse_indices = (
        cross_device_utils.split_by_sparsity(per_replica_values))
    if dense_values:
      dense_results = self._do_batch_all_reduce(reduce_op, dense_values)
    else:
      dense_results = []
    if sparse_values:
      sparse_results = self._do_batch_all_reduce_sparse(reduce_op,
                                                        sparse_values)
    else:
      sparse_results = []
    return cross_device_utils.stitch_values(((dense_results, dense_indices),
                                             (sparse_results, sparse_indices)))

  def _do_batch_all_reduce(self, reduce_op, dense_values):
    """Run batch all-reduces."""
    logging.log_first_n(
        logging.INFO,
        "batch_all_reduce: %d all-reduces with algorithm = %s, num_packs = %d" %
        (len(dense_values), self._all_reduce_alg, self._num_packs), 10)

    destinations = dense_values[0]._devices  # pylint: disable=protected-access
    grouped = _group_value_by_device(dense_values)

    # device_grad_packs:
    # [[(t0_gpu0, None), (t1_gpu0, None)], [(t0_gpu1, None), (t1_gpu1, None)]]
    device_grad_packs, tensor_packer = _pack_tensors(grouped, self._num_packs)

    # The actual aggregation of the repacked gradients. Note that they are
    # sharded among different aggregation trees. So it is important to strike
    # the balance on num_splits.
    if self._all_reduce_alg == "nccl":
      # TODO(yuefengz): merge this into the all-reduce library.
      reduced = cross_device_utils.aggregate_gradients_using_nccl(
          device_grad_packs)
    else:
      # TODO(yuefengz): check that gpu ids in `destinations` are in ascending
      # order.
      reduced = (
          cross_device_utils.aggregate_gradients_using_hierarchical_copy(
              destinations, device_grad_packs))

    reduced = _unpack_tensors(reduced, tensor_packer)
    return _ungroup_and_make_mirrored(reduced, dense_values[0], reduce_op)

  def _do_batch_all_reduce_sparse(self, reduce_op, sparse_values):
    """Run batch all-reduce for sparse values."""
    logging.log_first_n(
        logging.WARN,
        "Efficient allreduce is not supported for %d IndexedSlices" %
        len(sparse_values), 10)
    # Use `sparse_values` as destinations to do all-reduces. It is effectively
    # an allgather under the hood but not an efficient one.
    return self._simple_cross_replica_ops.batch_reduce(
        reduce_op, zip(sparse_values, sparse_values))

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    logging.log_first_n(
        logging.WARN,
        "gather/all_gather with NCCL or HierarchicalCopy is not supported. "
        "Falling back to gather on one device and then broadcast. We're working"
We're working" 943 " on a more efficient implementation.", 3) 944 return ReductionToOneDevice()._gather(per_replica_value, destinations, axis, # pylint: disable=protected-access 945 options) 946 947 948# For compatibility with code using the old name of `AllReduceCrossDeviceOps`. 949AllReduceCrossTowerOps = AllReduceCrossDeviceOps 950 951 952AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple", 953 "alg shards limit") 954 955 956@tf_export("distribute.NcclAllReduce") 957class NcclAllReduce(AllReduceCrossDeviceOps): 958 """NCCL all-reduce implementation of CrossDeviceOps. 959 960 It uses Nvidia NCCL for all-reduce. For the batch API, tensors will be 961 repacked or aggregated for more efficient cross-device transportation. 962 963 For reduces that are not all-reduce, it falls back to 964 `tf.distribute.ReductionToOneDevice`. 965 966 Here is how you can use `NcclAllReduce` in `tf.distribute.MirroredStrategy`: 967 968 969 ``` 970 strategy = tf.distribute.MirroredStrategy( 971 cross_device_ops=tf.distribute.NcclAllReduce()) 972 ``` 973 """ 974 975 def __init__(self, num_packs=1): 976 """Initializes the object. 977 978 Args: 979 num_packs: a non-negative integer. The number of packs to split values 980 into. If zero, no packing will be done. 981 982 Raises: 983 ValueError: if `num_packs` is negative. 984 """ 985 if num_packs < 0: 986 raise ValueError( 987 "NCCL all-reduce requires num_packs >= 0, but {} is specified".format( 988 num_packs)) 989 super(NcclAllReduce, self).__init__( 990 all_reduce_alg="nccl", num_packs=num_packs) 991 992 993@tf_export("distribute.HierarchicalCopyAllReduce") 994class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps): 995 """Hierarchical copy all-reduce implementation of CrossDeviceOps. 996 997 It reduces to one GPU along edges in some hierarchy and broadcasts back to 998 each GPU along the same path. For the batch API, tensors will be repacked or 999 aggregated for more efficient cross-device transportation. 1000 1001 This is a reduction created for Nvidia DGX-1 which assumes GPUs connects like 1002 that on DGX-1 machine. If you have different GPU inter-connections, it is 1003 likely that it would be slower than `tf.distribute.ReductionToOneDevice`. 1004 1005 For reduces that are not all-reduce, it falls back to 1006 `tf.distribute.ReductionToOneDevice`. 1007 1008 Here is how you can use `HierarchicalCopyAllReduce` in 1009 `tf.distribute.MirroredStrategy`: 1010 1011 ``` 1012 strategy = tf.distribute.MirroredStrategy( 1013 cross_device_ops=tf.distribute.HierarchicalCopyAllReduce()) 1014 ``` 1015 """ 1016 1017 def __init__(self, num_packs=1): 1018 """Initializes the object. 1019 1020 Args: 1021 num_packs: a non-negative integer. The number of packs to split values 1022 into. If zero, no packing will be done. 1023 1024 Raises: 1025 ValueError if `num_packs` is negative. 1026 """ 1027 if num_packs < 0: 1028 raise ValueError( 1029 "HierarchicalCopy requires num_packs >= 0, but {} is specified" 1030 .format(num_packs)) 1031 super(HierarchicalCopyAllReduce, self).__init__( 1032 all_reduce_alg="hierarchical_copy", 1033 num_packs=num_packs) 1034 1035 1036# TODO(crccw): remove after migrating all callers. 1037CollectiveCommunication = collective_util.CommunicationImplementation 1038CommunicationImplementation = collective_util.CommunicationImplementation 1039 1040 1041# TODO(yuefengz): support in-graph collective all-reduce. 1042class CollectiveAllReduce(CrossDeviceOps): 1043 """All-reduce cross device ops using collective ops. 

  In between-graph replicated training, it will still do all-reduces across
  all workers and then put results on the right destinations.
  """

  def __init__(self,
               devices,
               group_size,
               options,
               collective_keys=None,
               canonicalize_devices=True):
    """Initializes the object.

    Args:
      devices: a list of device strings to run collectives on.
      group_size: the global group size. For between-graph replicated training
        it's the total number of devices across all workers.
      options: a `tf.distribute.experimental.CommunicationOptions`.
      collective_keys: an optional CollectiveKey object.
      canonicalize_devices: Whether to canonicalize devices for workers or not.
    """
    if group_size % len(devices) > 0:
      raise ValueError("group_size must be divisible by the number of devices.")

    self._group_size = group_size
    self._options = options
    self._collective_keys = (collective_keys or
                             cross_device_utils.CollectiveKeys())
    # This lock guards all collective launches, i.e. calls to
    # cross_device_utils.build_collective_*.
    #
    # In a multi-threaded eager program we need to ensure different groups of
    # collectives don't interleave each other, otherwise there could be
    # deadlocks. E.g. if two user threads both are launching collectives:
    #   user-thread-0  device0  device1
    #   user-thread-1  device0  device1
    # In eager mode, we use one thread per device to launch collective ops, so
    # the above launch sequences end up with the following queues:
    #   device-0  collective-0  collective-1
    #   device-1  collective-1  collective-0
    # This deadlocks since neither collective is able to finish.
    self._lock = threading.Lock()

    if canonicalize_devices:
      self._devices = tuple(device_util.canonicalize(d) for d in devices)
    else:
      self._devices = tuple(
          device_util.canonicalize_without_job_and_task(d) for d in devices)
    group_key = self._collective_keys.get_group_key(self._devices)
    self._launchers = []
    # Whether to only use NCCL for batched all-reduce when NCCL is requested.
    # This is because of the lack of a mechanism to order NCCL operations
    # deterministically.
    self._limited_nccl = False
    for device in self._devices:
      launcher = cross_device_utils.CollectiveReplicaLauncher(
          group_key, group_size, self._collective_keys, device, options)
      self._launchers.append(launcher)
      if not launcher.can_order_nccl():
        self._limited_nccl = True

    super(CollectiveAllReduce, self).__init__()
    self._canonicalize_devices = canonicalize_devices

  @property
  def _num_between_graph_workers(self):
    # Currently we only support an equal number of devices on each worker.
    return self._group_size / len(self._devices)

  def _all_reduce(self, reduce_op, value, replica_id, options):
    """Implements CrossDeviceOps.all_reduce."""
    # TODO(b/122840926): reuse this method in _batch_all_reduce.
    flat_values = nest.flatten(value)

    # If NCCL launches can't be ordered (self._limited_nccl == True), we only
    # use NCCL when batch_size > 1, hoping that there's only one batched
    # all-reduce, which is the gradient aggregation in optimizer. For TF 2.x,
    # NCCL launches are always ordered.
    if (self._limited_nccl and options.implementation
        == collective_util.CommunicationImplementation.NCCL and
        len(flat_values) == 1):
      options = options.merge(
          collective_util.Options(
              implementation=collective_util.CommunicationImplementation.RING))

    launcher = self._launchers[replica_id]
    dense_values, dense_indices, sparse_values, sparse_indices = (
        cross_device_utils.split_by_sparsity(flat_values))
    dense_results = []
    sparse_results = []

    if dense_values:
      # Reverse the lists so that there's a better chance that values follow
      # the order in which they are calculated (e.g. when they're gradients), so
      # as to overlap calculation with communication. However, this may not be
      # optimal for cases like gradients of complicated non-sequential models.
      #
      # Note that we reverse the list before packing so that the first pack
      # won't be too small, since it's more likely for first few packs to have
      # long queuing time due to concurrent intense computation.
      #
      # TODO(b/147393503): explore solutions for optimal ordering.
      dense_values.reverse()
      packs = cross_device_utils.group_by_size(dense_values,
                                               options.bytes_per_pack)

      if not context.executing_eagerly() and replica_id == 0:
        logging.info(
            "Collective all_reduce tensors: %d all_reduces, num_devices = %d, "
            "group_size = %d, implementation = %s, num_packs = %d",
            len(dense_values), len(self._launchers), self._group_size,
            options.implementation, len(packs))

      dense_results = launcher.batch_all_reduce(packs, options)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        for i, v in enumerate(dense_results):
          with ops.device(self._devices[replica_id]):
            dense_results[i] = v / self._group_size
      dense_results.reverse()

    if sparse_values:
      if not context.executing_eagerly() and replica_id == 0:
        logging.info(
            "Collective all_reduce IndexedSlices: %d all_reduces, num_devices ="
            "%d, group_size = %d, implementation = %s", len(sparse_values),
            len(self._launchers), self._group_size, options.implementation)

      for indexed_slice in sparse_values:
        sparse_results.append(
            launcher.all_reduce_indexed_slices(indexed_slice, options))

      if reduce_op == reduce_util.ReduceOp.MEAN:
        for i, v in enumerate(sparse_results):
          with ops.device(self._devices[replica_id]):
            sparse_results[i] = indexed_slices.IndexedSlices(
                values=sparse_results[i].values / self._group_size,
                indices=sparse_results[i].indices,
                dense_shape=sparse_results[i].dense_shape)

    flat_results = cross_device_utils.stitch_values(
        ((dense_results, dense_indices), (sparse_results, sparse_indices)))
    return nest.pack_sequence_as(value, flat_results)

  def _all_reduce_per_replica_values(self, reduce_op, per_replica_values,
                                     options):
    """All reduce a list of per_replica_value."""
    values_by_device = [[] for _ in self._devices]
    num_devices = len(self._devices)
    for per_replica in per_replica_values:
      for i in range(num_devices):
        values_by_device[i].append(per_replica.values[i])

    if context.executing_eagerly():

      def thread_fn(device_id):
        with context.eager_mode():
          return self._all_reduce(reduce_op, values_by_device[device_id],
                                  device_id, options)

      with self._lock:
        pool = multiprocessing.pool.ThreadPool(len(self._devices))
        outputs_by_device = pool.map(thread_fn, list(range(num_devices)))
        pool.close()
    else:
      outputs_by_device = []
      with self._lock:
        for i in range(num_devices):
          outputs_by_device.append(
              self._all_reduce(reduce_op, values_by_device[i], i, options))

    result = []
    for values in zip(*outputs_by_device):
      result.append(
          distribute_utils.regroup(values, wrap_class=value_lib.Mirrored))
    return result

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    values_util.mark_as_unsaveable()
    all_reduced = self._all_reduce_per_replica_values(reduce_op,
                                                      [per_replica_value],
                                                      options)[0]
    devices = get_devices_from(destinations, self._canonicalize_devices)

    if _devices_match(per_replica_value, destinations,
                      self._canonicalize_devices):
      return all_reduced

    # Convert `all_reduced` to a `Mirrored` object, as a simple and uniform
    # utility to access component for a particular device.
    if not isinstance(all_reduced, value_lib.Mirrored):
      all_reduced = value_lib.Mirrored([all_reduced])

    # If we got this far, the destination devices do not match the all-reduce
    # devices, so we must map from one to the other.
    index = []
    # We must add these control dependencies, otherwise we can get deadlock.
    with ops.control_dependencies(all_reduced.values):
      for d in devices:
        with ops.device(d):
          for v in all_reduced.values:
            if v.device == d:
              index.append(array_ops.identity(v))
              break
          else:
            # TODO(josh11b): Once we add support for model parallelism, get the
            # copy from the corresponding replica instead of the primary.
            index.append(array_ops.identity(all_reduced._primary))  # pylint: disable=protected-access
    return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    values_util.mark_as_unsaveable()
    all_devices_match = _all_devices_match(value_destination_pairs,
                                           self._canonicalize_devices)
    if all_devices_match:
      return self._all_reduce_per_replica_values(
          reduce_op, [v[0] for v in value_destination_pairs], options)
    else:
      if not all_devices_match:
        logging.log_first_n(
            logging.WARN, "Efficient batch_reduce is not supported if "
            "destinations are different.", 10)

      return [
          self.reduce_implementation(reduce_op, value, dest, options)
          for value, dest in value_destination_pairs
      ]

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    all_gathered = self._batch_all_gather([per_replica_value], axis, options)[0]
    values_util.mark_as_unsaveable()
    devices = get_devices_from(destinations, self._canonicalize_devices)

    if _devices_match(per_replica_value, destinations,
                      self._canonicalize_devices):
      return all_gathered

    # Convert `all_gathered` to a `Mirrored` object, as a simple and uniform
    # utility to access component for a particular device.
    if not isinstance(all_gathered, value_lib.Mirrored):
      all_gathered = value_lib.Mirrored([all_gathered])

    # If we got this far, the destination devices do not match the all-gather
    # devices, so we must map from one to the other.
    index = []
    # We must add these control dependencies, otherwise we can get deadlock.
    with ops.control_dependencies(all_gathered.values):
      for d in devices:
        with ops.device(d):
          for v in all_gathered.values:
            if v.device == d:
              index.append(array_ops.identity(v))
              break
          else:
            index.append(array_ops.identity(all_gathered._primary))  # pylint: disable=protected-access
    return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)

  def _batch_all_gather(self, per_replica_values, axis, options):
    """All-gather multiple per-replica values."""
    batch_size = len(per_replica_values)
    # For now, we use NCCL only when batch_size > 1.
    # TODO(b/132575814): switch to NCCL for all collectives when implementation
    # is NCCL.
    if (self._limited_nccl and options.implementation
        == collective_util.CommunicationImplementation.NCCL and
        batch_size == 1):
      options = options.merge(
          collective_util.Options(
              implementation=collective_util.CommunicationImplementation.RING))

    logging.log_first_n(
        logging.INFO, "Collective batch_all_gather: %d all-gathers, "
        "num_devices = %d, group_size = %d, implementation = %s, " %
        (batch_size, len(
            self._devices), self._group_size, options.implementation), 10)

    def compute_gathered_values():
      gathered_values = []
      with self._lock, ops.name_scope("allgather"):
        for per_replica in per_replica_values:
          outputs = []
          for i in range(len(self._devices)):
            outputs.append(self._launchers[i].all_gather(
                per_replica.values[i], axis, options))
          gathered_values.append(outputs)
      return gathered_values

    if context.executing_eagerly():
      gathered_values = def_function.function(compute_gathered_values)()
    else:
      gathered_values = compute_gathered_values()

    mirrored = []
    for value in gathered_values:
      mirrored.append(
          distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
    return mirrored

  def __deepcopy__(self, memo):
    # distribute_coordinator deep-copies the strategy object, so
    # CollectiveAllReduce needs to support deep copy as well.
    collective_keys = copy.deepcopy(self._collective_keys, memo)
    return CollectiveAllReduce(self._devices, self._group_size, self._options,
                               collective_keys, self._canonicalize_devices)


def select_cross_device_ops(devices, session_config=None):
  """Find the best `CrossDeviceOps` locally given a `tf.compat.v1.ConfigProto`.

  Args:
    devices: a list of devices passed to `tf.distribute.Strategy`.
    session_config: a `tf.compat.v1.ConfigProto` or `None`. If `None`, it will
      make the decision based on all logical devices.

  Returns:
    An instance of a `CrossDeviceOps` subclass.
1362 """ 1363 requested_devices = set(device_util.canonicalize(d) for d in devices) 1364 if ops.executing_eagerly_outside_functions(): 1365 logical_gpus = context.context().list_logical_devices(device_type="GPU") 1366 physical_gpus = context.context().list_physical_devices(device_type="GPU") 1367 if len(logical_gpus) != len(physical_gpus): 1368 logging.warning("NCCL is not supported when using virtual GPUs, falling" 1369 "back to reduction to one device") 1370 return ReductionToOneDevice() 1371 1372 machine_devices = context.context().list_logical_devices() 1373 else: 1374 machine_devices = device_lib.list_local_devices( 1375 session_config=session_config) 1376 using_devices = set() 1377 for d in machine_devices: 1378 if device_util.canonicalize(d.name) in requested_devices: 1379 using_devices.add(d.name) 1380 1381 if len(using_devices) != len(requested_devices): 1382 logging.warning( 1383 "Some requested devices in `tf.distribute.Strategy` are not visible " 1384 "to TensorFlow: %s", ",".join(list(requested_devices - using_devices))) 1385 1386 if any("gpu" not in d.lower() for d in requested_devices): 1387 logging.warning("There are non-GPU devices in `tf.distribute.Strategy`, " 1388 "not using nccl allreduce.") 1389 return ReductionToOneDevice() 1390 1391 if kernels.get_registered_kernels_for_op("NcclAllReduce"): 1392 return NcclAllReduce(num_packs=1) 1393 else: 1394 logging.warning("Nccl kernel is not found, not using nccl allreduce.") 1395 return ReductionToOneDevice() 1396