1# mypy: allow-untyped-defs 2import copy 3import functools 4import inspect 5import itertools 6import logging 7import os 8import sys 9import warnings 10import weakref 11from collections import defaultdict, deque 12from contextlib import contextmanager 13from dataclasses import dataclass, fields, is_dataclass 14from enum import auto, Enum 15from typing import Any, Callable, List, Optional, Tuple, Type, TYPE_CHECKING 16 17import torch 18import torch.distributed as dist 19from torch._utils import _get_device_index 20from torch.autograd import Function, Variable 21from torch.distributed.algorithms.join import Join, Joinable, JoinHook 22from torch.nn.modules import Module 23from torch.nn.parallel.scatter_gather import gather, scatter_kwargs 24from torch.utils._pytree import tree_flatten, tree_unflatten 25 26 27RPC_AVAILABLE = False 28if dist.is_available(): 29 from torch.distributed.distributed_c10d import ( 30 _get_default_group, 31 _rank_not_in_group, 32 ReduceOp, 33 ) 34 from torch.distributed.utils import ( 35 _alloc_storage, 36 _cast_forward_inputs, 37 _free_storage, 38 _sync_module_states, 39 _to_kwargs, 40 _verify_param_shape_across_processes, 41 ) 42if dist.rpc.is_available(): 43 RPC_AVAILABLE = True 44 from torch.distributed.rpc import RRef 45 46if TYPE_CHECKING: 47 from torch.utils.hooks import RemovableHandle 48 49 50__all__ = ["DistributedDataParallel"] 51 52logger = logging.getLogger(__name__) 53 54 55@dataclass 56class _MixedPrecision: 57 """ 58 This configures DDP-native mixed precision training. 59 60 Attributes: 61 param_dtype (torch.dtype): This specifies the dtype for model 62 parameters, inputs (when ``cast_forward_inputs`` is set to 63 ``True``), and therefore the dtype for computation. 64 However, outside the forward and backward passes, parameters are in 65 full precision. Model checkpointing always happens in full 66 precision. 67 reduce_dtype (torch.dtype): This specifies the dtype for gradient 68 reduction, which is permitted to differ from ``param_dtype``. 69 buffer_dtype (torch.dtype): This specifies the dtype for buffers. 70 71 .. note:: This API is experimental and subject to change. 72 73 .. note:: Only floating point tensors are cast to their specified dtypes. 74 75 .. note:: ``state_dict`` checkpoints parameters and buffers in full 76 precision. 77 78 .. note:: Each low precision dtype must be specified explicitly. For 79 example, ``_MixedPrecision(reduce_dtype=torch.float16)`` only specifies 80 the reduction dtype to be low precision, and DDP will not cast 81 parameters or buffers. 82 83 .. note:: If a ``reduce_dtype`` is not specified, then gradient reduction 84 happens in ``param_dtype`` if specified or the original parameter dtype 85 otherwise. For example, ``_MixedPrecision(param_dtype=torch.float16)`` 86 would result in communication occurring in fp16. 87 """ 88 89 param_dtype: Optional[torch.dtype] = None 90 reduce_dtype: Optional[torch.dtype] = None 91 buffer_dtype: Optional[torch.dtype] = None 92 # TODO (rohan-varma): keep_low_precision_grads: bool = False 93 # TODO (rohan-varma): APIs to allow users to run batchnorm and layernorm 94 # in full precision. For DDP, this can be implemented by not performing the 95 # parameter cast for BN and LN units. 96 97 98def _cast_buffers(mixed_precision_config, root_module): 99 """Casts buffers to the given ``buffer_dtype``.""" 100 for buf in root_module.buffers(): 101 if hasattr(buf, "_ddp_ignored") and buf._ddp_ignored: 102 continue 103 104 buf.data = buf.to(dtype=mixed_precision_config.buffer_dtype) 105 106 107def _setup_mixed_precision_params(mixed_precision_config, root_module): 108 """Create and free storage for the mixed precision parameters.""" 109 for param in root_module.parameters(): 110 # Do not setup mixed precision for DDP ignored parameters. 111 if hasattr(param, "_ddp_ignored") and param._ddp_ignored: 112 continue 113 114 if not hasattr(param, "_mp_param"): 115 param._mp_param = torch.zeros_like( 116 param, 117 device=param.device, 118 dtype=mixed_precision_config.param_dtype, 119 requires_grad=param.requires_grad, 120 ) 121 _free_storage(param._mp_param) 122 # _fp_param will point to the full precision param so it can be switched 123 # back to at the end of forward / backward. 124 param._fp_param = param.data 125 126 127def _tree_flatten_with_rref(output): 128 output_is_rref = RPC_AVAILABLE and isinstance(output, RRef) 129 if output_is_rref: 130 output_tensor_list, treespec = tree_flatten(output.local_value()) 131 else: 132 output_tensor_list, treespec = tree_flatten(output) 133 # Need to return flattened tensors, spec to re-pack them, as well 134 # as if the return type was actually an RRef to reconstruct. 135 return output_tensor_list, treespec, output_is_rref 136 137 138def _tree_unflatten_with_rref(output, treespec, output_is_rref): 139 output = tree_unflatten(output, treespec) 140 if output_is_rref: 141 output = RRef(output) 142 return output 143 144 145def _find_tensors(obj): 146 r"""Recursively find all tensors contained in the specified object.""" 147 if RPC_AVAILABLE and isinstance(obj, RRef): 148 # If the current node is the owner of the RRef, unwrap it and try to 149 # find Tensors. 150 # TODO: Expand to remote RRefs. 151 if obj.is_owner(): 152 return _find_tensors(obj.local_value()) 153 if isinstance(obj, torch.Tensor): 154 return [obj] 155 if isinstance(obj, (list, tuple)): 156 return itertools.chain.from_iterable(map(_find_tensors, obj)) 157 if isinstance(obj, dict): 158 return itertools.chain.from_iterable(map(_find_tensors, obj.values())) 159 if is_dataclass(obj): 160 return itertools.chain.from_iterable( 161 map(_find_tensors, (getattr(obj, f.name) for f in fields(obj))) 162 ) 163 164 return [] 165 166 167def _dump_DDP_relevant_env_vars(): 168 relevant_env_vars = [ 169 "RANK", 170 "LOCAL_RANK", 171 "WORLD_SIZE", 172 "MASTER_PORT", 173 "MASTER_ADDR", 174 "CUDA_VISIBLE_DEVICES", 175 "GLOO_SOCKET_IFNAME", 176 "GLOO_DEVICE_TRANSPORT", 177 "NCCL_SOCKET_IFNAME", 178 "TORCH_NCCL_BLOCKING_WAIT", 179 "NCCL_DEBUG", 180 "NCCL_DEBUG_SUBSYS", 181 "NCCL_IB_DISABLE", 182 # More NCCL env vars: 183 "NCCL_P2P_DISABLE", 184 "NCCL_P2P_LEVEL", 185 "NCCL_SHM_DISABLE", 186 "NCCL_SOCKET_NTHREADS", 187 "NCCL_NSOCKS_PERTHREAD", 188 "NCCL_BUFFSIZE", 189 "NCCL_NTHREADS", 190 "NCCL_RINGS", 191 "NCCL_MAX_NCHANNELS", 192 "NCCL_MIN_NCHANNELS", 193 "NCCL_CHECKS_DISABLE", 194 "NCCL_CHECK_POINTERS", 195 "NCCL_LAUNCH_MODE", 196 "NCCL_IB_HCA", 197 "NCCL_IB_TIMEOUT", 198 "NCCL_IB_RETRY_CNT", 199 "NCCL_IB_GID_INDEX", 200 "NCCL_IB_SL", 201 "NCCL_IB_TC", 202 "NCCL_IB_AR_THRESHOLD", 203 "NCCL_IB_CUDA_SUPPORT", 204 "NCCL_NET_GDR_LEVEL", 205 "NCCL_NET_GDR_READ", 206 "NCCL_SINGLE_RING_THRESHOLD", 207 "NCCL_LL_THRESHOLD", 208 "NCCL_TREE_THRESHOLD", 209 "NCCL_ALGO", 210 "NCCL_PROTO", 211 "NCCL_IGNORE_CPU_AFFINITY", 212 "NCCL_DEBUG_FILE", 213 "NCCL_COLLNET_ENABLE", 214 "NCCL_TOPO_FILE", 215 "NCCL_TOPO_DUMP_FILE", 216 "TORCH_NCCL_ASYNC_ERROR_HANDLING", 217 ] 218 formatted_output = "" 219 for var in relevant_env_vars: 220 value = os.environ[var] if var in os.environ else "N/A" 221 formatted_output += f"env:{var}={value}\n" 222 print(formatted_output) 223 224 225class _BufferCommHookLocation(Enum): 226 PRE_FORWARD = auto() 227 POST_FORWARD = auto() 228 229 230@dataclass 231class _BufferCommHook: 232 buffer_comm_hook: Callable 233 buffer_comm_hook_state: Any 234 buffer_comm_hook_location: _BufferCommHookLocation 235 236 237# Add a DDPSink to run various functions when backwards starts, such as 238# queueing call back of out-most backward/graph task, 239# this helps call back is fired after all gradients' calculation 240# is completed. 241class _DDPSink(Function): 242 @staticmethod 243 def forward(ctx, ddp_weakref, *inputs): 244 # set_materialize_grads(False) will ensure that None gradients stay as 245 # None and are not filled with zeros. 246 ctx.set_materialize_grads(False) 247 ctx.ddp_weakref = ddp_weakref 248 ret = inputs 249 if ddp_weakref()._ddp_sink_clone: 250 ret = tuple( 251 inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs 252 ) 253 return ret 254 255 @staticmethod 256 def backward(ctx, *grad_outputs): 257 # Enqueue delay allreduce for static graph training on the first 258 # iteration. 259 ddp_weakref = ctx.ddp_weakref() 260 reducer = ddp_weakref.reducer 261 static_graph = ddp_weakref.static_graph 262 delay_ar_enqueued = ( 263 static_graph and ddp_weakref._static_graph_delay_allreduce_enqueued 264 ) 265 if static_graph and not delay_ar_enqueued: 266 Variable._execution_engine.queue_callback( # type: ignore[call-arg,misc] 267 reducer._delay_all_reduce 268 ) 269 ddp_weakref._static_graph_delay_allreduce_enqueued = True 270 271 return (None, *grad_outputs) 272 273 274class _DDPJoinHook(JoinHook): 275 def __init__(self, ddp, divide_by_initial_world_size): 276 """Set config variables for internal usage.""" 277 assert isinstance(ddp, DistributedDataParallel), ( 278 "DDP join hook requires passing in a DistributedDataParallel " 279 "instance as the state" 280 ) 281 assert ddp.logger is not None 282 ddp.logger._set_uneven_input_join() 283 self.ddp = ddp 284 self.ddp._divide_by_initial_world_size = divide_by_initial_world_size 285 super().__init__() 286 287 def main_hook(self): 288 """Shadow the DDP collective communication operations in the forward and backward passes.""" 289 ddp = self.ddp 290 # Buckets are rebuilt only once during a training period 291 ddp.reducer._rebuild_buckets() 292 293 # Schedule a broadcast if we are syncing module buffers in the 294 # forward pass 295 # TODO: make DDP uneven inputs context manager support buffer 296 # comm hook (https://github.com/pytorch/pytorch/issues/65436) 297 ddp._check_and_sync_module_buffers() 298 299 # Check if need to sync in the backward pass 300 should_sync_backwards = ddp._check_global_requires_backward_grad_sync( 301 is_joined_rank=True 302 ) 303 # Forward parameter sync is disabled in the next iteration if we 304 # are skipping gradient sync this iteration, so set 305 # `require_forward_param_sync` accordingly 306 ddp.require_forward_param_sync = should_sync_backwards 307 if not should_sync_backwards: 308 return 309 310 # Schedule one allreduce per gradient bucket to match the backward 311 # pass allreduce 312 ddp._match_all_reduce_for_bwd_pass() 313 314 # Check if we need to allreduce locally unused parameters 315 if ddp.find_unused_parameters: 316 ddp._match_unused_params_allreduce() 317 318 # Rebuilt parameters are pushed only once during a training period 319 ddp.reducer._push_all_rebuilt_params() 320 321 def post_hook(self, is_last_joiner: bool): 322 """Sync the final model to ensure that the model is the same across all processes.""" 323 self.ddp._sync_final_model(is_last_joiner) 324 325 326class DistributedDataParallel(Module, Joinable): 327 r"""Implement distributed data parallelism based on ``torch.distributed`` at module level. 328 329 This container provides data parallelism by synchronizing gradients 330 across each model replica. The devices to synchronize across are 331 specified by the input ``process_group``, which is the entire world 332 by default. Note that ``DistributedDataParallel`` does not chunk or 333 otherwise shard the input across participating GPUs; the user is 334 responsible for defining how to do so, for example through the use 335 of a :class:`DistributedSampler`. 336 337 See also: :ref:`distributed-basics` and :ref:`cuda-nn-ddp-instead`. 338 The same constraints on input as in :class:`torch.nn.DataParallel` apply. 339 340 Creation of this class requires that ``torch.distributed`` to be already 341 initialized, by calling :func:`torch.distributed.init_process_group`. 342 343 ``DistributedDataParallel`` is proven to be significantly faster than 344 :class:`torch.nn.DataParallel` for single-node multi-GPU data 345 parallel training. 346 347 To use ``DistributedDataParallel`` on a host with N GPUs, you should spawn 348 up ``N`` processes, ensuring that each process exclusively works on a single 349 GPU from 0 to N-1. This can be done by either setting 350 ``CUDA_VISIBLE_DEVICES`` for every process or by calling: 351 352 >>> # xdoctest: +SKIP("undefined variables") 353 >>> torch.cuda.set_device(i) 354 355 where i is from 0 to N-1. In each process, you should refer the following 356 to construct this module: 357 358 >>> # xdoctest: +SKIP("undefined variables") 359 >>> torch.distributed.init_process_group( 360 >>> backend='nccl', world_size=N, init_method='...' 361 >>> ) 362 >>> model = DistributedDataParallel(model, device_ids=[i], output_device=i) 363 364 In order to spawn up multiple processes per node, you can use either 365 ``torch.distributed.launch`` or ``torch.multiprocessing.spawn``. 366 367 .. note:: 368 Please refer to `PyTorch Distributed Overview <https://pytorch.org/tutorials/beginner/dist_overview.html>`__ 369 for a brief introduction to all features related to distributed training. 370 371 .. note:: 372 ``DistributedDataParallel`` can be used in conjunction with 373 :class:`torch.distributed.optim.ZeroRedundancyOptimizer` to reduce 374 per-rank optimizer states memory footprint. Please refer to 375 `ZeroRedundancyOptimizer recipe <https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html>`__ 376 for more details. 377 378 .. note:: ``nccl`` backend is currently the fastest and highly recommended 379 backend when using GPUs. This applies to both single-node and 380 multi-node distributed training. 381 382 .. note:: This module also supports mixed-precision distributed training. 383 This means that your model can have different types of parameters such 384 as mixed types of ``fp16`` and ``fp32``, the gradient reduction on these 385 mixed types of parameters will just work fine. 386 387 .. note:: If you use ``torch.save`` on one process to checkpoint the module, 388 and ``torch.load`` on some other processes to recover it, make sure that 389 ``map_location`` is configured properly for every process. Without 390 ``map_location``, ``torch.load`` would recover the module to devices 391 where the module was saved from. 392 393 .. note:: When a model is trained on ``M`` nodes with ``batch=N``, the 394 gradient will be ``M`` times smaller when compared to the same model 395 trained on a single node with ``batch=M*N`` if the loss is summed (NOT 396 averaged as usual) across instances in a batch (because the gradients 397 between different nodes are averaged). You should take this into 398 consideration when you want to obtain a mathematically equivalent 399 training process compared to the local training counterpart. But in most 400 cases, you can just treat a DistributedDataParallel wrapped model, a 401 DataParallel wrapped model and an ordinary model on a single GPU as the 402 same (E.g. using the same learning rate for equivalent batch size). 403 404 .. note:: 405 Parameters are never broadcast between processes. The module performs 406 an all-reduce step on gradients and assumes that they will be modified 407 by the optimizer in all processes in the same way. Buffers 408 (e.g. BatchNorm stats) are broadcast from the module in process of rank 409 0, to all other replicas in the system in every iteration. 410 411 .. note:: 412 If you are using DistributedDataParallel in conjunction with the 413 :ref:`distributed-rpc-framework`, you should always use 414 :meth:`torch.distributed.autograd.backward` to compute gradients and 415 :class:`torch.distributed.optim.DistributedOptimizer` for optimizing 416 parameters. 417 418 Example:: 419 420 >>> # xdoctest: +SKIP("undefined variables") 421 >>> import torch.distributed.autograd as dist_autograd 422 >>> from torch.nn.parallel import DistributedDataParallel as DDP 423 >>> import torch 424 >>> from torch import optim 425 >>> from torch.distributed.optim import DistributedOptimizer 426 >>> import torch.distributed.rpc as rpc 427 >>> from torch.distributed.rpc import RRef 428 >>> 429 >>> t1 = torch.rand((3, 3), requires_grad=True) 430 >>> t2 = torch.rand((3, 3), requires_grad=True) 431 >>> rref = rpc.remote("worker1", torch.add, args=(t1, t2)) 432 >>> ddp_model = DDP(my_model) 433 >>> 434 >>> # Setup optimizer 435 >>> optimizer_params = [rref] 436 >>> for param in ddp_model.parameters(): 437 >>> optimizer_params.append(RRef(param)) 438 >>> 439 >>> dist_optim = DistributedOptimizer( 440 >>> optim.SGD, 441 >>> optimizer_params, 442 >>> lr=0.05, 443 >>> ) 444 >>> 445 >>> with dist_autograd.context() as context_id: 446 >>> pred = ddp_model(rref.to_here()) 447 >>> loss = loss_func(pred, target) 448 >>> dist_autograd.backward(context_id, [loss]) 449 >>> dist_optim.step(context_id) 450 451 .. note:: 452 DistributedDataParallel currently offers limited support for gradient 453 checkpointing with :meth:`torch.utils.checkpoint`. 454 If the checkpoint is done with use_reentrant=False (recommended), DDP 455 will work as expected without any limitations. 456 If, however, the checkpoint is done with use_reentrant=True (the default), 457 DDP will work as expected when there are no unused parameters in the model 458 and each layer is checkpointed at most once (make sure you are not passing 459 `find_unused_parameters=True` to DDP). We currently do not support the 460 case where a layer is checkpointed multiple times, or when there unused 461 parameters in the checkpointed model. 462 463 .. note:: 464 To let a non-DDP model load a state dict from a DDP model, 465 :meth:`~torch.nn.modules.utils.consume_prefix_in_state_dict_if_present` 466 needs to be applied to strip the prefix "module." in the DDP state dict before loading. 467 468 .. warning:: 469 Constructor, forward method, and differentiation of the output (or a 470 function of the output of this module) are distributed synchronization 471 points. Take that into account in case different processes might be 472 executing different code. 473 474 .. warning:: 475 This module assumes all parameters are registered in the model by the 476 time it is created. No parameters should be added nor removed later. 477 Same applies to buffers. 478 479 .. warning:: 480 This module assumes all parameters are registered in the model of each 481 distributed processes are in the same order. The module itself will 482 conduct gradient ``allreduce`` following the reverse order of the 483 registered parameters of the model. In other words, it is users' 484 responsibility to ensure that each distributed process has the exact 485 same model and thus the exact same parameter registration order. 486 487 .. warning:: 488 This module allows parameters with non-rowmajor-contiguous strides. 489 For example, your model may contain some parameters whose 490 :class:`torch.memory_format` is ``torch.contiguous_format`` 491 and others whose format is ``torch.channels_last``. However, 492 corresponding parameters in different processes must have the 493 same strides. 494 495 .. warning:: 496 This module doesn't work with :func:`torch.autograd.grad` (i.e. it will 497 only work if gradients are to be accumulated in ``.grad`` attributes of 498 parameters). 499 500 .. warning:: 501 If you plan on using this module with a ``nccl`` backend or a ``gloo`` 502 backend (that uses Infiniband), together with a DataLoader that uses 503 multiple workers, please change the multiprocessing start method to 504 ``forkserver`` (Python 3 only) or ``spawn``. Unfortunately 505 Gloo (that uses Infiniband) and NCCL2 are not fork safe, and you will 506 likely experience deadlocks if you don't change this setting. 507 508 .. warning:: 509 You should never try to change your model's parameters after wrapping 510 up your model with ``DistributedDataParallel``. Because, when 511 wrapping up your model with ``DistributedDataParallel``, the constructor 512 of ``DistributedDataParallel`` will register the additional gradient 513 reduction functions on all the parameters of the model itself at the 514 time of construction. If you change the model's parameters afterwards, 515 gradient reduction functions no longer match the correct set of 516 parameters. 517 518 .. warning:: 519 Using ``DistributedDataParallel`` in conjunction with the 520 :ref:`distributed-rpc-framework` is experimental and subject to change. 521 522 Args: 523 module (Module): module to be parallelized 524 device_ids (list of int or torch.device): CUDA devices. 525 1) For single-device modules, ``device_ids`` can 526 contain exactly one device id, which represents the only 527 CUDA device where the input module corresponding to this process resides. 528 Alternatively, ``device_ids`` can also be ``None``. 529 2) For multi-device modules and CPU modules, 530 ``device_ids`` must be ``None``. 531 532 When ``device_ids`` is ``None`` for both cases, 533 both the input data for the forward pass and the actual module 534 must be placed on the correct device. 535 (default: ``None``) 536 output_device (int or torch.device): Device location of output for 537 single-device CUDA modules. For multi-device modules and 538 CPU modules, it must be ``None``, and the module itself 539 dictates the output location. (default: ``device_ids[0]`` 540 for single-device modules) 541 broadcast_buffers (bool): Flag that enables syncing (broadcasting) 542 buffers of the module at beginning of the ``forward`` 543 function. (default: ``True``) 544 process_group: The process group to be used for distributed data 545 all-reduction. If ``None``, the default process group, which 546 is created by :func:`torch.distributed.init_process_group`, 547 will be used. (default: ``None``) 548 bucket_cap_mb: ``DistributedDataParallel`` will bucket parameters into 549 multiple buckets so that gradient reduction of each 550 bucket can potentially overlap with backward computation. 551 :attr:`bucket_cap_mb` controls the bucket size in 552 MebiBytes (MiB). If ``None``, a default size of 25 MiB 553 will be used. (default: ``None``) 554 find_unused_parameters (bool): Traverse the autograd graph from all 555 tensors contained in the return value of the 556 wrapped module's ``forward`` function. Parameters 557 that don't receive gradients as part of this 558 graph are preemptively marked as being ready to 559 be reduced. In addition, parameters that may have 560 been used in the wrapped module's ``forward`` 561 function but were not part of loss computation and 562 thus would also not receive gradients are 563 preemptively marked as ready to be reduced. 564 (default: ``False``) 565 check_reduction: This argument is deprecated. 566 gradient_as_bucket_view (bool): When set to ``True``, gradients will be views 567 pointing to different offsets of ``allreduce`` communication 568 buckets. This can reduce peak memory usage, where the 569 saved memory size will be equal to the total gradients 570 size. Moreover, it avoids the overhead of copying between 571 gradients and ``allreduce`` communication buckets. When 572 gradients are views, ``detach_()`` cannot be called on the 573 gradients. If hitting such errors, please fix it by 574 referring to the :meth:`~torch.optim.Optimizer.zero_grad` 575 function in ``torch/optim/optimizer.py`` as a solution. 576 Note that gradients will be views after first iteration, so 577 the peak memory saving should be checked after first iteration. 578 static_graph (bool): When set to ``True``, DDP knows the trained graph is 579 static. Static graph means 1) The set of used and unused 580 parameters will not change during the whole training loop; in 581 this case, it does not matter whether users set 582 ``find_unused_parameters = True`` or not. 2) How the graph is trained 583 will not change during the whole training loop (meaning there is 584 no control flow depending on iterations). 585 When static_graph is set to be ``True``, DDP will support cases that 586 can not be supported in the past: 587 1) Reentrant backwards. 588 2) Activation checkpointing multiple times. 589 3) Activation checkpointing when model has unused parameters. 590 4) There are model parameters that are outside of forward function. 591 5) Potentially improve performance when there are unused parameters, 592 as DDP will not search graph in each iteration to detect unused 593 parameters when static_graph is set to be ``True``. 594 To check whether you can set static_graph to be ``True``, one way is to 595 check ddp logging data at the end of your previous model training, 596 if ``ddp_logging_data.get("can_set_static_graph") == True``, mostly you 597 can set ``static_graph = True`` as well. 598 599 Example:: 600 >>> # xdoctest: +SKIP("undefined variables") 601 >>> model_DDP = torch.nn.parallel.DistributedDataParallel(model) 602 >>> # Training loop 603 >>> ... 604 >>> ddp_logging_data = model_DDP._get_ddp_logging_data() 605 >>> static_graph = ddp_logging_data.get("can_set_static_graph") 606 delay_all_reduce_named_params (list of tuple of str and torch.nn.Parameter): a list 607 of named parameters whose all reduce will be delayed when the gradient of 608 the parameter specified in ``param_to_hook_all_reduce`` is ready. Other 609 arguments of DDP do not apply to named params specified in this argument 610 as these named params will be ignored by DDP reducer. 611 param_to_hook_all_reduce (torch.nn.Parameter): a parameter to hook delayed all reduce 612 of parameters specified in ``delay_all_reduce_named_params``. 613 614 615 Attributes: 616 module (Module): the module to be parallelized. 617 618 Example:: 619 620 >>> # xdoctest: +SKIP("undefined variables") 621 >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...') 622 >>> net = torch.nn.parallel.DistributedDataParallel(model) 623 """ 624 625 # used to track whether the given thread is inside ddp forward for torchdynamo purposes 626 _active_ddp_module: Optional["DistributedDataParallel"] = None 627 628 def __init__( 629 self, 630 module, 631 device_ids=None, 632 output_device=None, 633 dim=0, 634 broadcast_buffers=True, 635 process_group=None, 636 bucket_cap_mb=None, 637 find_unused_parameters=False, 638 check_reduction=False, 639 gradient_as_bucket_view=False, 640 static_graph=False, 641 delay_all_reduce_named_params=None, 642 param_to_hook_all_reduce=None, 643 mixed_precision: Optional[_MixedPrecision] = None, 644 device_mesh=None, 645 ): 646 super().__init__() 647 Joinable.__init__(self) 648 self.logger: Optional[dist.Logger] = None 649 if bool(delay_all_reduce_named_params is not None) != bool( 650 param_to_hook_all_reduce is not None 651 ): 652 self._log_and_throw( 653 ValueError, 654 "delay_all_reduce_named_params and param_to_hook_all_reduce " 655 "need to be set at the same time.", 656 ) 657 658 if process_group and device_mesh is not None: 659 raise RuntimeError( 660 "Cannot specify both process_group and device_mesh arguments." 661 ) 662 elif process_group is None and device_mesh is None: 663 self.process_group = _get_default_group() 664 elif device_mesh is None: 665 self.process_group = process_group 666 else: 667 if device_mesh.ndim != 1: 668 raise RuntimeError( 669 f"Only 1D device mesh is supported, but got {device_mesh}." 670 ) 671 self.device_mesh = device_mesh 672 self.process_group = device_mesh.get_group(mesh_dim=0) 673 from torch.distributed.device_mesh import _mesh_resources 674 675 root_mesh = _mesh_resources.get_root_mesh(device_mesh) 676 # if a root mesh is not the same as device_mesh, 677 # meaning the device_mesh is sliced out from the root mesh. 678 if root_mesh != device_mesh: 679 # TODO: This is a temporary work around to enable DDP + TP. 680 # We should do the logic in DDP so that the 2D implementation is 681 # sound and the state_dict works out of the box. 682 # This has to be done before check UninitializedParameter. 683 from torch.distributed.tensor.parallel.ddp import ( 684 _pre_dp_module_transform, 685 ) 686 687 _pre_dp_module_transform(module) 688 689 self._delay_all_reduce_params = [] 690 if hasattr(module, "_ddp_params_and_buffers_to_ignore"): 691 self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore) 692 else: 693 self.parameters_to_ignore = set() 694 if delay_all_reduce_named_params is not None: 695 for name, param in delay_all_reduce_named_params: 696 self.parameters_to_ignore.add(name) 697 self._delay_all_reduce_params.append(param) 698 699 self._module_parameters = [ 700 p 701 for n, p in module.named_parameters() 702 if n not in self.parameters_to_ignore 703 ] 704 if not any(p.requires_grad for p in self._module_parameters): 705 if len(self._delay_all_reduce_params): 706 logger.info("Delay the AllReduce of all parameters.") 707 else: 708 self._log_and_throw( 709 RuntimeError, 710 "DistributedDataParallel is not needed when a module " 711 "doesn't have any parameter that requires a gradient.", 712 ) 713 714 if device_ids is not None and len(device_ids) > 1: 715 self._log_and_throw( 716 ValueError, 717 "device_ids can only be None or contain a single element.", 718 ) 719 720 self.is_multi_device_module = ( 721 len({p.device for p in self._module_parameters}) > 1 722 ) 723 distinct_device_types = { 724 p.device.type for p in self._module_parameters if p.device is not None 725 } 726 if len(distinct_device_types) != 1: 727 self._log_and_throw( 728 ValueError, 729 "DistributedDataParallel's input module must be on " 730 f"the same type of devices, but input module parameters locate in {distinct_device_types}.", 731 ) 732 733 self.device_type = next(iter(distinct_device_types)) 734 735 if ( 736 device_ids is None 737 or len(device_ids) == 0 # For backward compatibility. 738 or self.device_type == "cpu" 739 or self.is_multi_device_module 740 ): 741 if device_ids or output_device: 742 self._log_and_throw( 743 ValueError, 744 "DistributedDataParallel device_ids and output_device arguments " 745 "only work with single-device/multiple-device GPU modules or CPU modules, " 746 f"but got device_ids {device_ids}, output_device {output_device}, " 747 f"and module parameters {({p.device for p in self._module_parameters})}.", 748 ) 749 750 self.device_ids = None 751 self.output_device = None 752 else: 753 self.device_ids = [_get_device_index(x, True) for x in device_ids] 754 755 if output_device is None: 756 output_device = device_ids[0] 757 758 self.output_device = _get_device_index(output_device, True) 759 760 self.static_graph = False 761 self.dim = dim 762 self.module = module 763 self.device = next(iter(self._module_parameters)).device 764 self.broadcast_buffers = broadcast_buffers 765 self.find_unused_parameters = find_unused_parameters 766 self.require_backward_grad_sync = True 767 self.require_forward_param_sync = True 768 self.gradient_as_bucket_view = gradient_as_bucket_view 769 self.mixed_precision = mixed_precision 770 if self.mixed_precision is not None: 771 logger.warning("Received mixed precision config %s", self.mixed_precision) 772 773 if check_reduction: 774 # This argument is no longer used since the reducer 775 # will ensure reduction completes even if some parameters 776 # do not receive gradients. 777 warnings.warn( 778 "The `check_reduction` argument in `DistributedDataParallel` " 779 "module is deprecated. Please avoid using it.", 780 FutureWarning, 781 stacklevel=2, 782 ) 783 784 # Check that a module does not have Uninitialized parameters 785 for param in self._module_parameters: 786 if isinstance(param, torch.nn.parameter.UninitializedParameter): 787 self._log_and_throw( 788 RuntimeError, 789 "Modules with uninitialized parameters can't be used with `DistributedDataParallel`. " 790 "Run a dummy forward pass to correctly initialize the modules", 791 ) 792 # used for intra-node param sync and inter-node sync as well 793 self.broadcast_bucket_size = int(250 * 1024 * 1024) 794 795 # reduction bucket size 796 if bucket_cap_mb is None: 797 # default case (bucket cap is 25 MiB) 798 bucket_cap_mb = 25 799 self.bucket_bytes_cap_default = True 800 else: 801 self.bucket_bytes_cap_default = False 802 self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024) 803 804 # Whether to perform input tensor CPU to GPU copies on a side-stream 805 self.use_side_stream_for_tensor_copies = ( 806 os.environ.get("PYTORCH_DDP_USE_SIDE_STREAM", "1") == "1" 807 ) 808 809 # Initialize gradient buffers and register all reduce hook 810 self._delay_grad_buffer: Optional[torch.Tensor] = None 811 self._delay_grad_views: List[torch.Tensor] = [] 812 self._delay_all_reduce_all_params = False 813 if len(self._delay_all_reduce_params) != 0: 814 self._register_delay_all_reduce_hook( 815 bucket_cap_mb=bucket_cap_mb, 816 param_to_hook_all_reduce=param_to_hook_all_reduce, 817 device_ids=device_ids, 818 ) 819 if self._delay_all_reduce_all_params: 820 return 821 822 # Build parameters for reducer. 823 parameters, expect_sparse_gradient = self._build_params_for_reducer() 824 # Verify model equivalence. 825 _verify_param_shape_across_processes(self.process_group, parameters) 826 # Sync params and buffers. Ensures all DDP models start off at the same value. 827 _sync_module_states( 828 module=self.module, 829 process_group=self.process_group, 830 broadcast_bucket_size=self.broadcast_bucket_size, 831 src=0, 832 params_and_buffers_to_ignore=self.parameters_to_ignore, 833 broadcast_buffers=self.broadcast_buffers, 834 ) 835 # In debug mode, build a mapping of parameter index -> parameter. 836 param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters) 837 838 # Builds reducer. 839 self._ddp_init_helper( 840 parameters, 841 expect_sparse_gradient, 842 param_to_name_mapping, 843 static_graph, 844 ) 845 self._comm_hooks: List[Tuple[Callable, object]] = [] 846 847 if self.mixed_precision is not None: 848 _setup_mixed_precision_params(self.mixed_precision, self.module) 849 _cast_buffers(self.mixed_precision, self.module) 850 # Stream used for async low precision copies. 851 self._mp_stream = torch.cuda.Stream() 852 self._submodule_to_event = defaultdict(deque) # type: ignore[var-annotated] 853 # Add forward pre-hook to root module to kick off copies to lower 854 # precision. 855 self.module.register_forward_pre_hook( 856 self._root_copy_hook, prepend=False, with_kwargs=True 857 ) 858 # Add forward pre hook to all submodules to wait for copy events 859 # before running computation. 860 for module in self.module.modules(): 861 module.register_forward_pre_hook( 862 self._module_wait_for_copy_hook, 863 prepend=False, 864 with_kwargs=True, 865 ) 866 # Set up callbacks in backward to upcast and use full precision 867 # params. TODO (rohan-varma): Make this compose with general 868 # comm hooks and apply_optimizer_in_backward. Importing inline to 869 # avoid circular import issue. 870 from torch.distributed.algorithms.ddp_comm_hooks.mixed_precision_hooks import ( 871 _AllreduceUpcastHookState, 872 _reducer_allreduce_and_upcast_hook, 873 ) 874 875 upcast_hook_state = _AllreduceUpcastHookState( 876 ddp_weakref=weakref.ref(self), 877 upcast_stream=torch.cuda.Stream(), 878 ) 879 self.register_comm_hook( 880 upcast_hook_state, 881 _reducer_allreduce_and_upcast_hook, 882 ) 883 # Inform reducer of reduced precision param dtype for correctness 884 # of type checks between gradient and bucket. 885 self.reducer._set_mixed_precision_param_dtype( # type: ignore[attr-defined] 886 self.mixed_precision.param_dtype 887 ) 888 889 self._has_rebuilt_buckets = False 890 891 if static_graph: 892 self._set_static_graph() 893 894 self._lazy_init_ran = False 895 896 # Register the AccumulateGrad post hooks if optimize_ddp is 897 # True. The hooks will be deregistered if compiled_autograd is not 898 # enabled. 899 self._accum_grad_hooks: List[RemovableHandle] = [] 900 optimize_ddp = torch._dynamo.config._get_optimize_ddp_mode() 901 self._use_python_reducer = optimize_ddp in ( 902 "python_reducer", 903 "python_reducer_without_compiled_forward", 904 ) 905 if self._use_python_reducer: 906 torch._inductor.config._fuse_ddp_communication = True 907 torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb 908 # Directly adding this to the trace rule will disturb the users 909 # who are using DDPOptimizer. 910 torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST.add( 911 "torch.nn.parallel.distributed" 912 ) 913 torch._dynamo.trace_rules.get_legacy_mod_inlinelist.cache_clear() 914 self._force_to_disable_cpp_reducer = ( 915 optimize_ddp == "python_reducer_without_compiled_forward" 916 ) 917 if self._use_python_reducer: 918 self._register_accum_grad_hook() 919 920 # Whether or not DDPSink performs a clone. 921 self._ddp_sink_clone = True 922 923 def _register_accum_grad_hook(self): 924 import torch.distributed._functional_collectives as fcol 925 926 def compiled_accum_grad_hook( 927 param, 928 *, 929 param_index: int, 930 ): 931 if not self.require_backward_grad_sync: 932 return 933 934 if param.grad is None: 935 return 936 937 if self._comm_hooks: 938 for hook, state in self._comm_hooks: 939 hook(state, (param.grad, param)) 940 else: 941 gradient = param.grad / self.process_group.size() 942 gradient = fcol.all_reduce(gradient, "sum", self.process_group) 943 param.grad.copy_(gradient) 944 945 for index, param in enumerate(self._module_parameters): 946 if not param.requires_grad: 947 continue 948 self._accum_grad_hooks.append( 949 param.register_post_accumulate_grad_hook( 950 functools.partial( 951 compiled_accum_grad_hook, 952 param_index=index, 953 ) 954 ) 955 ) 956 957 def _delayed_all_reduce_hook(self, grad): 958 world_size = dist.get_world_size(self.process_group) 959 960 self._delay_grad_buffer.div_(world_size) # type: ignore[union-attr] 961 _ = dist.all_reduce( 962 self._delay_grad_buffer, group=self.process_group, async_op=True 963 ) 964 return grad 965 966 def _register_delay_all_reduce_hook( 967 self, 968 bucket_cap_mb, 969 param_to_hook_all_reduce, 970 device_ids, 971 ): 972 # 1. Create gradient buffer 973 device = torch.device("cpu") if device_ids is None else device_ids[0] 974 self._delay_grad_buffer = torch.zeros( 975 sum(p.numel() for p in self._delay_all_reduce_params), 976 device=device, 977 ) 978 979 # 2. Broadcast the parameters 980 detached_params = [p.detach() for p in self._delay_all_reduce_params] 981 dist._broadcast_coalesced(self.process_group, detached_params, bucket_cap_mb, 0) 982 983 # 3. Hook all reduce to the specified parameter 984 param_to_hook_all_reduce.register_hook(self._delayed_all_reduce_hook) 985 986 # 4. Build tensor views for gradients 987 offset = 0 988 for param in self._delay_all_reduce_params: 989 grad_view = self._delay_grad_buffer[offset : (offset + param.numel())].view( 990 param.shape 991 ) 992 self._delay_grad_views.append(grad_view) 993 offset = offset + param.numel() 994 995 # 5. Check whether the all reduce of all params requiring grad is delayed. 996 for module_name, module in self.module.named_modules(): 997 for param_name, param in module.named_parameters(recurse=False): 998 if param.requires_grad: 999 full_name = f"{module_name}.{param_name}" 1000 if full_name not in self.parameters_to_ignore: 1001 # There is at least a param whose all reduce will not be delayed. 1002 # In this case, we should not set self._delay_all_reduce_all_params 1003 # to True. 1004 return 1005 self._delay_all_reduce_all_params = True 1006 1007 def _setup_in_backward_optimizers(self): 1008 # Check if user has used apply_optim_in_backward to overlap optimizer 1009 # step + DDP backward. Current constraints: 1010 # 1. Only allreduce is supported at the moment, no custom communication. 1011 # 2. For DDP-managed parameters that have their optimizer run in 1012 # backward, their gradients are set to ``None``. If your use case 1013 # requires DDP parameters grad not to be set to ``None`` after their 1014 # in-backward optimizer runs, please ping 1015 # https://github.com/pytorch/pytorch/issues/90052. 1016 # NOTE: we use self._module_parameters instead of .parameters() since 1017 # the former excludes ignored (non-DDP managed) parameters. 1018 if any(hasattr(p, "_in_backward_optimizers") for p in self._module_parameters): 1019 torch._C._log_api_usage_once("ddp.optimizer_in_backward") 1020 # Remove hooks that apply_optim_in_backward had registered because 1021 # DDP customizes how optimizer is overlapped with backward due to 1022 # the allreduce. 1023 param_to_handle_map = ( 1024 dist.optim.apply_optimizer_in_backward.param_to_optim_hook_handle_map 1025 ) 1026 for p in self._module_parameters: 1027 for handle in param_to_handle_map.get(p, []): 1028 handle.remove() 1029 1030 # Need a weakref to DDP instance to run all_reduce (from reducer) 1031 # and get managed DDP parameters. 1032 ddp_weakref = weakref.ref(self) 1033 # Note: importing in function, otherwise this will cause a circular 1034 # import. 1035 from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import ( 1036 _apply_optim_in_backward_hook, 1037 ) 1038 1039 self.register_comm_hook( 1040 ddp_weakref, 1041 _apply_optim_in_backward_hook( 1042 gradient_is_bucket_view=self.gradient_as_bucket_view 1043 ), 1044 ) 1045 1046 self.reducer._set_optimizer_in_backward() # type: ignore[attr-defined] 1047 1048 def _fire_reducer_autograd_hook(self, idx, *unused): 1049 """ 1050 Fire the reducer's autograd hook to allreduce params in a Reducer bucket. 1051 1052 Note that this is only used during mixed precision training as the 1053 Reducer's hooks installed during construction time would not be called 1054 as we're working in the low precision parameter setting. 1055 """ 1056 self.reducer._autograd_hook(idx) # type: ignore[attr-defined] 1057 1058 def _root_copy_hook(self, *args: Any, **kwargs: Any) -> None: 1059 """ 1060 For DDP mixed precision, put low precision copies on separate stream and create events to wait for them. 1061 1062 When training with DDP mixed precision, this root pre-forward hook kicks 1063 off low precision copies on a separate stream and creates respective 1064 events to wait for them. 1065 """ 1066 # Clear out previous iteration submodule to event. This is because we 1067 # may have populated some events for modules that didn't end up being 1068 # used. 1069 self._submodule_to_event = defaultdict(deque) # type: ignore[var-annotated] 1070 with torch.cuda.stream(self._mp_stream): 1071 for submodule in self.module.modules(): 1072 for param in submodule.parameters(recurse=False): 1073 # Do not cast DDP ignored parameters. 1074 if hasattr(param, "_ddp_ignored") and param._ddp_ignored: 1075 continue 1076 _alloc_storage(param._mp_param, param.size()) 1077 # copy() implicitly casts to low precision 1078 with torch.no_grad(): 1079 param._mp_param.copy_(param.data) 1080 # TODO: when zero_grad(set_to_none=False) or in grad 1081 # accumulation case, accumulated grads can be in fp32 1082 # which can cause errors when running DDP backwards due 1083 # to mismatched incoming and accumulated gradient types. 1084 # So we manually cast the accumulated grad down for now, 1085 # in the future we may shift to FSDP style gradient 1086 # accumulation management where the accumulated gradient 1087 # is saved and .grad field is set to None, bypassing 1088 # this issue. 1089 if param.grad is not None: 1090 param.grad.data = param.grad.to( 1091 self.mixed_precision.param_dtype # type: ignore[union-attr] 1092 ) 1093 param.data = param._mp_param 1094 copy_event = torch.cuda.Event() 1095 copy_event.record() 1096 self._submodule_to_event[submodule].append(copy_event) 1097 1098 def _module_wait_for_copy_hook( 1099 self, 1100 module, 1101 *args: Any, 1102 **kwargs: Any, 1103 ) -> None: 1104 """Before carrying out computation, wait on the appropriate event to ensure low precision copies have finished.""" 1105 try: 1106 event = self._submodule_to_event[module].popleft() 1107 except IndexError: 1108 # copy event has already been waited on 1109 return 1110 1111 event.wait(stream=torch.cuda.current_stream()) 1112 for p in module.parameters(recurse=False): 1113 # Don't register hooks if param does not require grad 1114 if not p.requires_grad or (hasattr(p, "_ddp_ignored") and p._ddp_ignored): 1115 continue 1116 # We need to register autograd hook here instead of DDP's ctor 1117 # since we're working with the low precision param. Register them 1118 # via obtaining the gradient accumulator. 1119 tmp = p.expand_as(p) 1120 grad_acc = tmp.grad_fn.next_functions[0][0] 1121 1122 hook = grad_acc.register_hook( 1123 functools.partial(self._fire_reducer_autograd_hook, p._idx) 1124 ) 1125 p._ddp_mp_hook_state = (grad_acc, hook) 1126 1127 def _log_and_throw(self, err_type, err_msg): 1128 if self.logger is not None: 1129 self.logger.set_error_and_log(f"{str(err_type)}: {err_msg}") 1130 raise err_type(err_msg) 1131 1132 def _ddp_init_helper( 1133 self, 1134 parameters, 1135 expect_sparse_gradient, 1136 param_to_name_mapping, 1137 static_graph, 1138 ): 1139 """ 1140 DDP init helper function to manage parameters, grad hooks, logging, and SyncBatchNorm. 1141 1142 Initialization helper function that does the following: 1143 (1) bucketing the parameters for reductions 1144 (2) resetting the bucketing states 1145 (3) registering the grad hooks 1146 (4) Logging construction-time DDP logging data 1147 (5) passing a handle of DDP to SyncBatchNorm Layer 1148 """ 1149 # Notice, the parameters order is not in the order in which they are used, 1150 # especially in models with control flow. 1151 # 1152 # Alongside parameters are not presented in the real execution order, 1153 # if a certain model happens to also 1154 # 1) have other collectives comm ops in its backward graph. 1155 # 2) have unused parameter in subset ranks of the whole world. 1156 # bucketing could insert ALL-REDUCE comm op too early on the rank with unused parameter, 1157 # matching up with other collectives comm ops on other ranks unexpectedly. 1158 # 1159 # In order to handle this corner case, when the parameters are not in the real execution order, 1160 # we don't do bucketing, thus only one ALL-REDUCE is inserted after all the gradients 1161 # of the whole graph are computed. 1162 # 1163 # Notice, here we only disable bucketing for the first iteration. 1164 # After the first iteration, it's OK to rebuild buckets, 1165 # because "bucket rebuild" bucketizes parameters based on its real execution order in backward graph. 1166 1167 # Can remove this branching once #73732 is landed. 1168 if static_graph is True or self.find_unused_parameters is False: 1169 bucket_size_limits = [sys.maxsize] 1170 else: 1171 if self.bucket_bytes_cap_default: 1172 bucket_size_limits = [ 1173 dist._DEFAULT_FIRST_BUCKET_BYTES, 1174 self.bucket_bytes_cap, 1175 ] 1176 else: 1177 bucket_size_limits = [self.bucket_bytes_cap] 1178 ( 1179 bucket_indices, 1180 per_bucket_size_limits, 1181 ) = dist._compute_bucket_assignment_by_size( 1182 parameters, 1183 bucket_size_limits, 1184 expect_sparse_gradient, 1185 ) 1186 1187 # Remember index for parameters if we are in mixed precision, as we 1188 # need to pass in index to Reducer's autograd hook via python. 1189 if self.mixed_precision is not None: 1190 for i, p in enumerate(parameters): 1191 p._idx = i 1192 1193 # Note: reverse list of buckets because we want to approximate the 1194 # order in which their gradients are produced, and assume they 1195 # are used in the forward pass in the order they are defined. 1196 self.reducer = dist.Reducer( 1197 parameters, 1198 list(reversed(bucket_indices)), 1199 list(reversed(per_bucket_size_limits)), 1200 self.process_group, 1201 expect_sparse_gradient, 1202 # The bucket size limit is specified in the constructor. 1203 # Additionally, we allow for a single small bucket for parameters 1204 # that are defined first, such that their gradients don't spill into 1205 # a much larger bucket, adding unnecessary latency after gradient 1206 # computation finishes. Experiments showed 1MB is a reasonable value. 1207 self.bucket_bytes_cap, 1208 self.find_unused_parameters, 1209 self.gradient_as_bucket_view, 1210 param_to_name_mapping, 1211 # User can set dist._DEFAULT_FIRST_BUCKET_BYTES to tune DDP first 1212 # bucket. 1213 ( 1214 dist._DEFAULT_FIRST_BUCKET_BYTES 1215 if self.bucket_bytes_cap_default 1216 else self.bucket_bytes_cap 1217 ), 1218 ) 1219 1220 self.logger = dist.Logger(self.reducer) 1221 # Set as a weak reference to avoid reference cycle between 1222 # logger and reducer. 1223 self.reducer.set_logger(self.logger) 1224 1225 has_sync_bn = False 1226 for submodule in self.module.modules(): 1227 if isinstance(submodule, torch.nn.SyncBatchNorm): 1228 has_sync_bn = True 1229 break 1230 1231 # Set logging data that can be got during construction time. 1232 self.logger.set_construction_data_and_log( 1233 self.module.__class__.__name__, 1234 [] if self.device_ids is None else self.device_ids, 1235 -1 if self.output_device is None else self.output_device, 1236 self.broadcast_buffers, 1237 has_sync_bn, 1238 static_graph, 1239 ) 1240 1241 # passing a handle to torch.nn.SyncBatchNorm layer 1242 self._passing_sync_batchnorm_handle(self.module) 1243 1244 def __getstate__(self): 1245 self._check_default_group() 1246 attrs = copy.copy(self.__dict__) 1247 del attrs["process_group"] 1248 del attrs["reducer"] 1249 del attrs["logger"] 1250 return attrs 1251 1252 def __setstate__(self, state): 1253 # If serializable, then the process group should be the default one 1254 self.process_group = _get_default_group() 1255 super().__setstate__(state) 1256 self.__dict__.setdefault("require_forward_param_sync", True) 1257 self.__dict__.setdefault("require_backward_grad_sync", True) 1258 parameters, expect_sparse_gradient = self._build_params_for_reducer() 1259 # In debug mode, build a mapping of parameter index -> parameter. 1260 param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters) 1261 # Builds reducer. 1262 self._ddp_init_helper( 1263 parameters, 1264 expect_sparse_gradient, 1265 param_to_name_mapping, 1266 self.static_graph, 1267 ) 1268 if self.static_graph: 1269 self.reducer._set_static_graph() 1270 assert self.logger is not None 1271 self.logger._set_static_graph() 1272 1273 def _build_params_for_reducer(self): 1274 # Build tuple of (module, parameter) for all parameters that require grads. 1275 modules_and_parameters = [ 1276 (module, parameter) 1277 for module_name, module in self.module.named_modules() 1278 for parameter in [ 1279 param 1280 # Note that we access module.named_parameters instead of 1281 # parameters(module). parameters(module) is only needed in the 1282 # single-process multi device case, where it accesses replicated 1283 # parameters through _former_parameters. 1284 for param_name, param in module.named_parameters(recurse=False) 1285 if param.requires_grad 1286 and f"{module_name}.{param_name}" not in self.parameters_to_ignore 1287 ] 1288 ] 1289 1290 # Deduplicate any parameters that might be shared across child modules. 1291 memo = set() 1292 modules_and_parameters = [ 1293 # "p not in memo" is the deduplication check. 1294 # "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed. 1295 (m, p) 1296 for m, p in modules_and_parameters 1297 if p not in memo and not memo.add(p) # type: ignore[func-returns-value] 1298 ] 1299 1300 # Build list of parameters. 1301 parameters = [parameter for _, parameter in modules_and_parameters] 1302 1303 # Checks if a module will produce a sparse gradient. 1304 def produces_sparse_gradient(module): 1305 if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)): 1306 return module.sparse 1307 return False 1308 1309 # Build list of booleans indicating whether or not to expect sparse 1310 # gradients for the corresponding parameters. 1311 expect_sparse_gradient = [ 1312 produces_sparse_gradient(module) for module, _ in modules_and_parameters 1313 ] 1314 1315 self._assign_modules_buffers() 1316 1317 return parameters, expect_sparse_gradient 1318 1319 def _assign_modules_buffers(self): 1320 """ 1321 Assign self.module.named_buffers to self.modules_buffers. 1322 1323 Assigns module buffers to self.modules_buffers which are then used to 1324 broadcast across ranks when broadcast_buffers=True. Note that this 1325 must be called every time buffers need to be synced because buffers can 1326 be reassigned by user module, 1327 see https://github.com/pytorch/pytorch/issues/63916. 1328 """ 1329 # Collect buffers for modules, filtering out buffers that should be ignored. 1330 named_module_buffers = [ 1331 (buffer, buffer_name) 1332 for buffer_name, buffer in self.module.named_buffers() 1333 if buffer_name not in self.parameters_to_ignore 1334 ] 1335 self.modules_buffers = [ 1336 buffer for (buffer, buffer_name) in named_module_buffers 1337 ] 1338 # Dict[str, tensor] representing module buffers not ignored by DDP. 1339 self.named_module_buffers = { 1340 buffer_name: buffer for (buffer, buffer_name) in named_module_buffers 1341 } 1342 1343 def _build_debug_param_to_name_mapping(self, parameters): 1344 param_to_param_index = {parameters[i]: i for i in range(len(parameters))} 1345 param_set = set(parameters) 1346 param_index_to_param_fqn = {} 1347 for module_name, module in self.module.named_modules(): 1348 for param_name, param in module.named_parameters(recurse=False): 1349 fqn = f"{module_name}.{param_name}" 1350 # Bypass ignored parameters since those are not reduced by DDP 1351 # to begin with. 1352 if fqn not in self.parameters_to_ignore and param.requires_grad: 1353 if param not in param_set: 1354 self._log_and_throw( 1355 ValueError, 1356 f"Param with name {fqn} found in module parameters, but not DDP parameters." 1357 " This indicates a bug in DDP, please report an issue to PyTorch.", 1358 ) 1359 param_index = param_to_param_index[param] 1360 param_index_to_param_fqn[param_index] = fqn 1361 1362 # Ensure we covered all parameters 1363 if len(param_set) != len(param_index_to_param_fqn): 1364 self._log_and_throw( 1365 ValueError, 1366 ( 1367 "Expected param to name mapping to cover all parameters, but" 1368 f" got conflicting lengths: {len(param_set)} vs " 1369 f"{len(param_index_to_param_fqn)}. This indicates a bug in DDP" 1370 ", please report an issue to PyTorch." 1371 ), 1372 ) 1373 1374 return param_index_to_param_fqn 1375 1376 def _get_parameters(self, m, recurse=True): 1377 """Return a generator of module parameters.""" 1378 1379 def model_parameters(m): 1380 ps = ( 1381 m._former_parameters.values() 1382 if hasattr(m, "_former_parameters") 1383 else m.parameters(recurse=False) 1384 ) 1385 yield from ps 1386 1387 for mod in m.modules() if recurse else [m]: 1388 yield from model_parameters(mod) 1389 1390 def _check_default_group(self): 1391 pickle_not_supported = False 1392 try: 1393 if self.process_group != _get_default_group(): 1394 pickle_not_supported = True 1395 except RuntimeError: 1396 pickle_not_supported = True 1397 1398 if pickle_not_supported: 1399 self._log_and_throw( 1400 RuntimeError, 1401 "DDP Pickling/Unpickling are only supported " 1402 "when using DDP with the default process " 1403 "group. That is, when you have called " 1404 "init_process_group and have not passed " 1405 "process_group argument to DDP constructor", 1406 ) 1407 1408 @contextmanager 1409 def no_sync(self): 1410 r""" 1411 Context manager to disable gradient synchronizations across DDP processes. 1412 1413 Within this context, gradients will be accumulated on module 1414 variables, which will later be synchronized in the first 1415 forward-backward pass exiting the context. 1416 1417 Example:: 1418 1419 >>> # xdoctest: +SKIP("undefined variables") 1420 >>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg) 1421 >>> with ddp.no_sync(): 1422 >>> for input in inputs: 1423 >>> ddp(input).backward() # no synchronization, accumulate grads 1424 >>> ddp(another_input).backward() # synchronize grads 1425 1426 .. warning:: 1427 The forward pass should be included inside the context manager, or 1428 else gradients will still be synchronized. 1429 """ 1430 old_require_backward_grad_sync = self.require_backward_grad_sync 1431 self.require_backward_grad_sync = False 1432 try: 1433 yield 1434 finally: 1435 self.require_backward_grad_sync = old_require_backward_grad_sync 1436 1437 @classmethod 1438 def _get_active_ddp_module(cls): 1439 """`TorchDynamo` requires DDP's status and module for cooperative optimization.""" 1440 return cls._active_ddp_module 1441 1442 # note, this ctxmgr function is marked 'skip' in torchdynamo, so dynamo only kicks in 1443 # for the 'module_to_run' underneath 1444 # see torch._dynamo/eval_frame.py TorchPatcher.patch for more details 1445 @contextmanager 1446 @torch._disable_dynamo(recursive=False) 1447 def _inside_ddp_forward(self): 1448 DistributedDataParallel._active_ddp_module = self 1449 try: 1450 yield 1451 finally: 1452 DistributedDataParallel._active_ddp_module = None 1453 1454 def _run_ddp_forward(self, *inputs, **kwargs): 1455 if self._use_python_reducer: 1456 return self.module(*inputs, **kwargs) # type: ignore[index] 1457 else: 1458 with self._inside_ddp_forward(): 1459 return self.module(*inputs, **kwargs) # type: ignore[index] 1460 1461 def _clear_grad_buffer(self): 1462 # Making param.grad points to the grad buffers before backward is based on the 1463 # assumption that the grad accumulation is done in place in autograd engine, 1464 # for some edge cases, if the grad accumulation in autograd engine is not in 1465 # place, then the param.grad and grad buffers are detached. 1466 if self._delay_grad_buffer is not None: 1467 # We batch zero_grad for all params by resetting the whole grad 1468 # buffer when the grad of all params is set to None. 1469 all_param_grad_none = all( 1470 param.grad is None for param in self._delay_all_reduce_params 1471 ) 1472 1473 for index, param in enumerate(self._delay_all_reduce_params): 1474 if param.grad is None: 1475 param.grad = self._delay_grad_views[index] 1476 if not all_param_grad_none: 1477 param.grad.zero_() 1478 1479 if all_param_grad_none: 1480 self._delay_grad_buffer.zero_() 1481 1482 def _lazy_init(self): 1483 # Initialization for DDP that occurs after construction, but lazily 1484 # before the first forward pass. 1485 self._setup_in_backward_optimizers() 1486 self._lazy_init_ran = True 1487 1488 def _should_disable_cpp_reducer(self) -> bool: 1489 return self._use_python_reducer and ( 1490 torch._utils.is_compiling() or self._force_to_disable_cpp_reducer 1491 ) 1492 1493 def _pre_forward(self, *inputs, **kwargs): 1494 if self._should_disable_cpp_reducer(): 1495 return inputs, kwargs 1496 1497 # Disable the python reducer if compiled_autograd is not enabled. 1498 if self._accum_grad_hooks: 1499 for index, h in enumerate(self._accum_grad_hooks): 1500 h.remove() 1501 self._accum_grad_hooks.clear() 1502 1503 if not self._lazy_init_ran and not torch._utils.is_compiling(): 1504 self._lazy_init() 1505 1506 if self._delay_all_reduce_all_params: 1507 return inputs, kwargs 1508 1509 if torch.is_grad_enabled() and self.require_backward_grad_sync: 1510 assert self.logger is not None 1511 self.logger.set_runtime_stats_and_log() 1512 self.reducer.prepare_for_forward() 1513 1514 # Notify the join context that this process has not joined, if 1515 # needed 1516 work = Join.notify_join_context(self) 1517 if work: 1518 self.reducer._set_forward_pass_work_handle( 1519 work, self._divide_by_initial_world_size # type: ignore[arg-type] 1520 ) 1521 1522 # Calling _rebuild_buckets before forward computation, 1523 # It may allocate new buckets before deallocating old buckets 1524 # inside _rebuild_buckets. To save peak memory usage, 1525 # call _rebuild_buckets before the peak memory usage increases 1526 # during forward computation. 1527 # This should be called only once during whole training period. 1528 if torch.is_grad_enabled() and self.reducer._rebuild_buckets(): 1529 logger.info("Reducer buckets have been rebuilt in this iteration.") 1530 self._has_rebuilt_buckets = True 1531 1532 # sync params according to location (before/after forward) user 1533 # specified as part of hook, if hook was specified. 1534 if self._check_sync_bufs_pre_fwd(): 1535 self._sync_buffers() 1536 1537 if self._join_config.enable: 1538 # Notify joined ranks whether they should sync in backwards pass or not. 1539 self._check_global_requires_backward_grad_sync(is_joined_rank=False) 1540 1541 if self.device_ids: 1542 moved_inputs, moved_kwargs = _to_kwargs( 1543 inputs, 1544 kwargs, 1545 torch.device(self.device_type, self.device_ids[0]), 1546 self.use_side_stream_for_tensor_copies, 1547 ) 1548 args, kwargs = moved_inputs[0], moved_kwargs[0] 1549 # Cast inputs to reduced precision if needed. 1550 if self.mixed_precision is not None: 1551 args, kwargs = _cast_forward_inputs( 1552 self.mixed_precision.param_dtype, 1553 *args, 1554 **kwargs, 1555 ) 1556 return args, kwargs 1557 else: 1558 # Cast inputs to reduced precision if needed. 1559 # TODO (rohan-varma) test this codepath. 1560 if self.mixed_precision is not None: 1561 inputs, kwargs = _cast_forward_inputs( 1562 self.mixed_precision.param_dtype, 1563 *inputs, 1564 **kwargs, 1565 ) 1566 return inputs, kwargs 1567 1568 def _post_forward(self, output): 1569 if self._should_disable_cpp_reducer(): 1570 return output 1571 1572 if self._delay_all_reduce_all_params: 1573 self._clear_grad_buffer() 1574 return output 1575 1576 # sync params according to location (before/after forward) user 1577 # specified as part of hook, if hook was specified. 1578 if self._check_sync_bufs_post_fwd(): 1579 self._sync_buffers() 1580 1581 if torch.is_grad_enabled() and self.require_backward_grad_sync: 1582 self.require_forward_param_sync = True 1583 # We'll return the output object verbatim since it is a freeform 1584 # object. We need to find any tensors in this object, though, 1585 # because we need to figure out which parameters were used during 1586 # this forward pass, to ensure we short circuit reduction for any 1587 # unused parameters. Only if `find_unused_parameters` is set. 1588 if self.find_unused_parameters and not self.static_graph: 1589 # Do not need to populate this for static graph. 1590 self.reducer.prepare_for_backward(list(_find_tensors(output))) 1591 else: 1592 self.reducer.prepare_for_backward([]) 1593 else: 1594 self.require_forward_param_sync = False 1595 1596 # TODO: DDPSink is currently enabled for unused parameter detection and 1597 # static graph training for first iteration. 1598 if (self.find_unused_parameters and not self.static_graph) or ( 1599 self.static_graph and not self._static_graph_delay_allreduce_enqueued 1600 ): 1601 ( 1602 output_tensor_list, 1603 treespec, 1604 output_is_rref, 1605 ) = _tree_flatten_with_rref(output) 1606 output_placeholders: List[Optional[torch.Tensor]] = [ 1607 None for _ in range(len(output_tensor_list)) 1608 ] 1609 # Do not touch tensors that have no grad_fn, which can cause issues 1610 # such as https://github.com/pytorch/pytorch/issues/60733 1611 for i, output in enumerate(output_tensor_list): 1612 if torch.is_tensor(output) and output.grad_fn is None: 1613 output_placeholders[i] = output 1614 1615 # When find_unused_parameters=True, makes tensors which require grad 1616 # run through the DDPSink backward pass. When not all outputs are 1617 # used in loss, this makes those corresponding tensors receive 1618 # undefined gradient which the reducer then handles to ensure 1619 # param.grad field is not touched and we don't error out. 1620 passthrough_tensor_list = _DDPSink.apply( 1621 weakref.ref(self), 1622 *output_tensor_list, 1623 ) 1624 for i in range(len(output_placeholders)): 1625 if output_placeholders[i] is None: 1626 output_placeholders[i] = passthrough_tensor_list[i] 1627 1628 # Reconstruct output data structure. 1629 output = _tree_unflatten_with_rref( 1630 output_placeholders, treespec, output_is_rref 1631 ) 1632 1633 # At the end of the forward pass, reset the grad buffer and grad views 1634 self._clear_grad_buffer() 1635 return output 1636 1637 def forward(self, *inputs, **kwargs): 1638 with torch.autograd.profiler.record_function("DistributedDataParallel.forward"): 1639 inputs, kwargs = self._pre_forward(*inputs, **kwargs) 1640 output = ( 1641 self.module.forward(*inputs, **kwargs) 1642 if self._delay_all_reduce_all_params 1643 else self._run_ddp_forward(*inputs, **kwargs) 1644 ) 1645 return self._post_forward(output) 1646 1647 def scatter(self, inputs, kwargs, device_ids): 1648 return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) 1649 1650 def to_kwargs(self, inputs, kwargs, device_id): 1651 # Kept for BC 1652 return _to_kwargs( 1653 inputs, 1654 kwargs, 1655 torch.device(self.device_type, device_id), 1656 self.use_side_stream_for_tensor_copies, 1657 ) 1658 1659 def gather(self, outputs, output_device): 1660 return gather(outputs, output_device, dim=self.dim) 1661 1662 def train(self, mode=True): 1663 super().train(mode) 1664 return self 1665 1666 # When running in join mode, schedules an allreduce to notify joined ranks 1667 # of whether backwards pass synchronization will run this iteration or not. 1668 def _check_global_requires_backward_grad_sync(self, is_joined_rank): 1669 if not is_joined_rank and self.require_backward_grad_sync: 1670 requires_sync_tensor = torch.ones(1, device=self.device) 1671 else: 1672 requires_sync_tensor = torch.zeros(1, device=self.device) 1673 1674 work = dist.all_reduce( 1675 requires_sync_tensor, group=self.process_group, async_op=True 1676 ) 1677 1678 # (kwen2501) This if condition is a plain translation of previous 1679 # behavior, i.e. in the `is_joined_rank=False` case, `work.wait()` 1680 # is not called and it doesn't care about the result. I am guessing 1681 # that it just wants to fire a matching all-reduce and does not want 1682 # the main stream to wait. 1683 if is_joined_rank: 1684 work.wait() 1685 should_sync_backwards = requires_sync_tensor.item() != 0 1686 return should_sync_backwards 1687 else: 1688 return None # Return value is not/should not be used. 1689 1690 # When running in join mode, checks and performs sync of module buffers if 1691 # the models have buffers that should be synchronized in the forward pass. 1692 def _check_and_sync_module_buffers(self): 1693 if self._check_sync_bufs_pre_fwd(): 1694 authoritative_rank = self._find_common_rank(self._distributed_rank, False) 1695 self._sync_module_buffers(authoritative_rank) 1696 1697 # When running in join model, agrees upon a common rank and broadcast model 1698 # parameters to all other ranks. 1699 def _sync_final_model(self, is_last_joiner): 1700 # Agree upon the process that will be the authoritative model copy. 1701 # The current rank is a candidate for being the authoritative copy if 1702 # is_last_joiner=True. We break ties via picking the larger rank. 1703 self._authoritative_rank = self._find_common_rank( 1704 self._distributed_rank, is_last_joiner 1705 ) 1706 _sync_module_states( 1707 module=self.module, 1708 process_group=self.process_group, 1709 broadcast_bucket_size=self.broadcast_bucket_size, 1710 src=self._authoritative_rank, 1711 params_and_buffers_to_ignore=self.parameters_to_ignore, 1712 broadcast_buffers=self.broadcast_buffers, 1713 ) 1714 1715 # Schedule comm ops to match those scheduled in the reducer's backward 1716 # pass. 1717 def _match_all_reduce_for_bwd_pass(self): 1718 comm_work = [] 1719 # Schedule comm in the same order as Reducer schedules them, i.e. 1720 # the order of the buckets. Retrieving the bucket order from the reducer 1721 # ensures that we keep the same order in join mode, such as when bucket 1722 # order is rebuilt dynamically. 1723 1724 # Returns grad_buckets in order, but real tensors are substituted with 1725 # zero tensors of the same shape. 1726 grad_buckets = self.reducer._get_zeros_like_grad_buckets() 1727 for grad_bucket in grad_buckets: 1728 # Joined processes contribute zero gradient. In the case that 1729 # divide_by_initial_world_size=True, we divide grads by the static 1730 # world size, if not, the dividing factor is reduced by the number 1731 # of joined processes. 1732 work = self.reducer._run_comm_hook(grad_bucket) 1733 comm_work.append(work) 1734 for work in comm_work: 1735 work.wait() 1736 1737 # Allreduces the used parameter mapping across ranks. 1738 def _match_unused_params_allreduce(self): 1739 locally_used_param_map = self.reducer._get_local_used_map() 1740 self.process_group.allreduce(locally_used_param_map) 1741 1742 def join( 1743 self, 1744 divide_by_initial_world_size: bool = True, 1745 enable: bool = True, 1746 throw_on_early_termination: bool = False, 1747 ): 1748 r""" 1749 Context manager for training with uneven inputs across processes in DDP. 1750 1751 This context manager will keep track of already-joined DDP processes, 1752 and "shadow" the forward and backward passes by inserting collective 1753 communication operations to match with the ones created by non-joined 1754 DDP processes. This will ensure each collective call has a corresponding 1755 call by already-joined DDP processes, preventing hangs or errors that 1756 would otherwise happen when training with uneven inputs across 1757 processes. Alternatively, if the flag ``throw_on_early_termination`` is 1758 specified to be ``True``, all trainers will throw an error once one rank 1759 runs out of inputs, allowing these errors to be caught and handled 1760 according to application logic. 1761 1762 Once all DDP processes have joined, the context manager will broadcast 1763 the model corresponding to the last joined process to all processes to 1764 ensure the model is the same across all processes 1765 (which is guaranteed by DDP). 1766 1767 To use this to enable training with uneven inputs across processes, 1768 simply wrap this context manager around your training loop. No further 1769 modifications to the model or data loading is required. 1770 1771 .. warning:: 1772 If the model or training loop this context manager is wrapped around 1773 has additional distributed collective operations, such as 1774 ``SyncBatchNorm`` in the model's forward pass, then the flag 1775 ``throw_on_early_termination`` must be enabled. This is because this 1776 context manager is not aware of non-DDP collective communication. 1777 This flag will cause all ranks to throw when any one rank 1778 exhausts inputs, allowing these errors to be caught and recovered 1779 from across all ranks. 1780 1781 Args: 1782 divide_by_initial_world_size (bool): If ``True``, will divide 1783 gradients by the initial ``world_size`` DDP training was launched 1784 with. If ``False``, will compute the effective world size 1785 (number of ranks that have not depleted their inputs yet) and 1786 divide gradients by that during allreduce. Set 1787 ``divide_by_initial_world_size=True`` to ensure every input 1788 sample including the uneven inputs have equal weight in terms of 1789 how much they contribute to the global gradient. This is 1790 achieved by always dividing the gradient by the initial 1791 ``world_size`` even when we encounter uneven inputs. If you set 1792 this to ``False``, we divide the gradient by the remaining 1793 number of nodes. This ensures parity with training on a smaller 1794 ``world_size`` although it also means the uneven inputs would 1795 contribute more towards the global gradient. Typically, you 1796 would want to set this to ``True`` for cases where the last few 1797 inputs of your training job are uneven. In extreme cases, where 1798 there is a large discrepancy in the number of inputs, setting 1799 this to ``False`` might provide better results. 1800 enable (bool): Whether to enable uneven input detection or not. Pass 1801 in ``enable=False`` to disable in cases where you know that 1802 inputs are even across participating processes. Default is 1803 ``True``. 1804 throw_on_early_termination (bool): Whether to throw an error 1805 or continue training when at least one rank has exhausted 1806 inputs. If ``True``, will throw upon the first rank reaching end 1807 of data. If ``False``, will continue training with a smaller 1808 effective world size until all ranks are joined. Note that if 1809 this flag is specified, then the flag 1810 ``divide_by_initial_world_size`` would be ignored. Default 1811 is ``False``. 1812 1813 1814 Example:: 1815 1816 >>> # xdoctest: +SKIP("Distributed") 1817 >>> import torch 1818 >>> import torch.distributed as dist 1819 >>> import os 1820 >>> import torch.multiprocessing as mp 1821 >>> import torch.nn as nn 1822 >>> # On each spawned worker 1823 >>> def worker(rank): 1824 >>> dist.init_process_group("nccl", rank=rank, world_size=2) 1825 >>> torch.cuda.set_device(rank) 1826 >>> model = nn.Linear(1, 1, bias=False).to(rank) 1827 >>> model = torch.nn.parallel.DistributedDataParallel( 1828 >>> model, device_ids=[rank], output_device=rank 1829 >>> ) 1830 >>> # Rank 1 gets one more input than rank 0. 1831 >>> inputs = [torch.tensor([1]).float() for _ in range(10 + rank)] 1832 >>> with model.join(): 1833 >>> for _ in range(5): 1834 >>> for inp in inputs: 1835 >>> loss = model(inp).sum() 1836 >>> loss.backward() 1837 >>> # Without the join() API, the below synchronization will hang 1838 >>> # blocking for rank 1's allreduce to complete. 1839 >>> torch.cuda.synchronize(device=rank) 1840 """ 1841 return Join( 1842 [self], 1843 enable, 1844 throw_on_early_termination, 1845 divide_by_initial_world_size=divide_by_initial_world_size, 1846 ) 1847 1848 def join_hook( 1849 self, 1850 **kwargs, 1851 ): 1852 r""" 1853 DDP join hook enables training on uneven inputs by mirroring communications in forward and backward passes. 1854 1855 Arguments: 1856 kwargs (dict): a :class:`dict` containing any keyword arguments 1857 to modify the behavior of the join hook at run time; all 1858 :class:`Joinable` instances sharing the same join context 1859 manager are forwarded the same value for ``kwargs``. 1860 1861 The hook supports the following keyword arguments: 1862 divide_by_initial_world_size (bool, optional): 1863 If ``True``, then gradients are divided by the initial world 1864 size that DDP was launched with. 1865 If ``False``, then gradients are divided by the effective world 1866 size (i.e. the number of non-joined processes), meaning that 1867 the uneven inputs contribute more toward the global gradient. 1868 Typically, this should be set to ``True`` if the degree of 1869 unevenness is small but can be set to ``False`` in extreme 1870 cases for possibly better results. 1871 Default is ``True``. 1872 """ 1873 divide_by_initial_world_size = kwargs.get("divide_by_initial_world_size", True) 1874 return _DDPJoinHook( 1875 self, divide_by_initial_world_size=divide_by_initial_world_size 1876 ) 1877 1878 @property 1879 def join_device(self): 1880 return self.device 1881 1882 @property 1883 def join_process_group(self): 1884 return self.process_group 1885 1886 def _register_buffer_comm_hook( 1887 self, 1888 state, 1889 hook: Callable, 1890 comm_hook_location=_BufferCommHookLocation.POST_FORWARD, 1891 ): 1892 r""" 1893 Allow custom registration of hooks that define how buffer are synchronized across ranks. 1894 1895 The hook takes in an optional state and is passed in a Dict[str, Tensor] 1896 corresponding to buffer names and the buffers, and can run arbitrary reductions 1897 on buffers as opposed to DDP's default broadcast from rank 0. This is useful for 1898 example if a counter needs to be summed or averaged across ranks every iteration. 1899 1900 Args: 1901 state (Any): Optional state that is passed to the hook. 1902 hook (Callable): Callable with the following signature: 1903 ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]`` 1904 comm_hook_location (_BufferCommHookLocation): Enum value indicating 1905 where to run the hook. 1906 _BufferCommHookLocation.PRE_FORWARD means that the 1907 hook will run _before_ the forward pass, and 1908 _BufferCommHookLocation.POST_FORWARD means that the 1909 hook will run _after_ the forward pass. 1910 1911 NOTE: To maximize performance, users can return a 1912 List[torch.futures.Future] from their hook, and DDP will 1913 install and await these hooks appropriately at the end of 1914 the backward pass. This will ensure all buffers are 1915 synchronized by the end of the backward pass. If this 1916 setting is used, it is recommended to pass 1917 comm_hook_location=_BufferCommHookLocation.POST_FORWARD, 1918 which will trigger the hook after the forward pass. 1919 If _BufferCommHookLocation.PRE_FORWARD is used, users must 1920 ensure appropriate synchronization when manipulating GPU 1921 buffers in the forward pass. 1922 """ 1923 assert callable(hook) 1924 self.buffer_hook = _BufferCommHook( 1925 buffer_comm_hook=hook, 1926 buffer_comm_hook_state=state, 1927 buffer_comm_hook_location=comm_hook_location, 1928 ) 1929 1930 def register_comm_hook(self, state: object, hook: Callable): 1931 r""" 1932 Register communication hook for user-defined DDP aggregation of gradients across multiple workers. 1933 1934 This hook would be very useful for researchers to try out new ideas. For 1935 example, this hook can be used to implement several algorithms like GossipGrad 1936 and gradient compression which involve different communication strategies for 1937 parameter syncs while running Distributed DataParallel training. 1938 1939 Args: 1940 state (object): Passed to the hook to maintain any state information during the training process. 1941 Examples include error feedback in gradient compression, 1942 peers to communicate with next in GossipGrad, etc. 1943 1944 It is locally stored by each worker 1945 and shared by all the gradient tensors on the worker. 1946 hook (Callable): Callable with the following signature: 1947 ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``: 1948 1949 This function is called once the bucket is ready. The 1950 hook can perform whatever processing is needed and return 1951 a Future indicating completion of any async work (ex: allreduce). 1952 If the hook doesn't perform any communication, it still 1953 must return a completed Future. The Future should hold the 1954 new value of grad bucket's tensors. Once a bucket is ready, 1955 c10d reducer would call this hook and use the tensors returned 1956 by the Future and copy grads to individual parameters. 1957 Note that the future's return type must be a single tensor. 1958 1959 We also provide an API called ``get_future`` to retrieve a 1960 Future associated with the completion of ``c10d.ProcessGroup.Work``. 1961 ``get_future`` is currently supported for NCCL and also supported for most 1962 operations on GLOO and MPI, except for peer to peer operations (send/recv). 1963 1964 .. warning :: 1965 Grad bucket's tensors will not be predivided by world_size. User is responsible 1966 to divide by the world_size in case of operations like allreduce. 1967 1968 .. warning :: 1969 DDP communication hook can only be registered once and should be registered 1970 before calling backward. 1971 1972 .. warning :: 1973 The Future object that hook returns should contain a single tensor 1974 that has the same shape with the tensors inside grad bucket. 1975 1976 .. warning :: 1977 ``get_future`` API supports NCCL, and partially GLOO and MPI backends (no support 1978 for peer-to-peer operations like send/recv) and will return a ``torch.futures.Future``. 1979 1980 Example:: 1981 Below is an example of a noop hook that returns the same tensor. 1982 1983 >>> # xdoctest: +SKIP('undefined name') 1984 >>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]: 1985 >>> fut = torch.futures.Future() 1986 >>> fut.set_result(bucket.buffer()) 1987 >>> return fut 1988 >>> ddp.register_comm_hook(state=None, hook=noop) 1989 1990 Example:: 1991 Below is an example of a Parallel SGD algorithm where gradients are encoded before 1992 allreduce, and then decoded after allreduce. 1993 1994 >>> # xdoctest: +SKIP('undefined name') 1995 >>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]: 1996 >>> encoded_tensor = encode(bucket.buffer()) # encode gradients 1997 >>> fut = torch.distributed.all_reduce(encoded_tensor).get_future() 1998 >>> # Define the then callback to decode. 1999 >>> def decode(fut): 2000 >>> decoded_tensor = decode(fut.value()[0]) # decode gradients 2001 >>> return decoded_tensor 2002 >>> return fut.then(decode) 2003 >>> ddp.register_comm_hook(state=None, hook=encode_and_decode) 2004 """ 2005 self._check_comm_hook(hook) 2006 assert self.logger is not None 2007 self.logger._set_comm_hook_name(hook.__qualname__) 2008 self._comm_hooks.append((hook, state)) 2009 dist._register_comm_hook(self.reducer, state, hook) 2010 2011 def _register_builtin_comm_hook(self, comm_hook_type): 2012 r""" 2013 Register a built-in communication hook that specifies how DDP aggregates gradients across multiple workers. 2014 2015 The built-in hooks aim to provide efficient C++ implementations for certain hooks, 2016 which might not be as efficient if implemented in Python using a Python communication hook. 2017 2018 Args: 2019 comm_hook_type (dist.BuiltinCommHookType): type of communication hook, such as ALLREDUCE, FP16_COMPRESS, etc. 2020 2021 .. warning :: 2022 DDP communication hook can only be registered once and should be registered 2023 before calling backward. 2024 2025 Example:: 2026 Below is an example of a FP16 compression where gradients are 2027 compressed into 16-bit floating-point numbers before allreduce, and 2028 then decompressed after allreduce. 2029 2030 >>> # xdoctest: +SKIP('undefined name') 2031 >>> ddp._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS) 2032 2033 """ 2034 assert self.logger is not None 2035 self.logger._set_comm_hook_name(str(comm_hook_type)) 2036 dist._register_builtin_comm_hook(self.reducer, comm_hook_type) 2037 2038 def _register_fused_optim(self, optim: Type, *args, optim_params=None, **kwargs): 2039 r""" 2040 Register an optimizer in DDP to optimize parameter immediately after its gradient reduction. 2041 2042 Registers an optimizer with DDP such that the optimization for a 2043 parameter will run immediately when that parameter's gradient is 2044 finished with reduction, instead of waiting for all parameters' 2045 gradients to finish reduction. This can result in a training speedup 2046 depending on your workload since the optimizer can run while gradient 2047 reduction for other parameters are still ongoing. In addition, this has 2048 the potential to reduce peak memory consumption during training, as it 2049 only needs to load the per-parameter optimizer states of a single 2050 parameter at a time, instead of loading all per-parameter optimizer 2051 states at once. 2052 2053 Args: 2054 optim (Type): a ``torch.optim.Optimizer`` class to be registered 2055 as a fused optimizer. 2056 *args (Sequence[Any]): Arguments to forward to `optim`. 2057 optim_params (Optional[Iterable[torch.Tensor]]): Set of parameters 2058 to optimize, similar to `params` argument of traditional `torch.optim` 2059 Optimizers. If this is omitted, all DDP model parameters will be 2060 optimized. 2061 **kwargs: (Dict[str, Any]): Keyword arguments to forward to `optim`. 2062 2063 .. warning :: 2064 _register_fused_optim should only be called once on a DDP instance, 2065 and registering multiple fused optimizers for the same DDP model 2066 is not currently supported. Please ping 2067 https://github.com/pytorch/pytorch/issues/71595 if this is necessary 2068 for your use case. 2069 2070 .. warning :: 2071 _register_fused_optim and register_comm_hook currently do not 2072 compose together, meaning that custom DDP communication hooks are 2073 not supported with overlapped optimizers. Please ping 2074 https://github.com/pytorch/pytorch/issues/71595 if this is necessary 2075 for your use case. 2076 2077 .. warning :: 2078 Gradient accumulation and DDP `no_sync` are currently not supported 2079 with overlapped optimizer. Please ping 2080 https://github.com/pytorch/pytorch/issues/71595 if this is necessary 2081 for your use case. 2082 2083 Example:: 2084 2085 >>> # xdoctest: +SKIP("No rendezvous handler") 2086 >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...') 2087 >>> net = torch.nn.parallel.DistributedDataParallel(model, pg) 2088 >>> lr = 1e-2 2089 >>> betas = (0.9, 0.99) 2090 >>> eps = 1e-6 2091 >>> net._register_fused_optim(torch.optim.Adam, lr, betas=betas, eps=eps) 2092 >>> # Example with subset of parameters 2093 >>> params_to_opt = [list(net.parameters())[0]] 2094 >>> net._register_fused_optim( 2095 ... torch.optim.Adam, lr, optim_params=params_to_opt, betas=betas, eps=eps 2096 ... ) 2097 """ 2098 # Note: importing in function, otherwise this will cause a circular 2099 # import as optimizer_overlap module needs to import DistributedDataParallel. 2100 from torch.distributed.algorithms._optimizer_overlap import _as_overlapped_optim 2101 2102 overlapped_optim = _as_overlapped_optim(optim, optim_params, *args, **kwargs) 2103 try: 2104 overlapped_optim.register_ddp(self) 2105 except NotImplementedError as e: 2106 raise RuntimeError( 2107 f"{optim} does not support overlapped DDP. Please file an issue to PyTorch or the respective owner of {optim}." 2108 ) from e 2109 2110 def _distributed_broadcast_coalesced( 2111 self, tensors, buffer_size, authoritative_rank=0 2112 ): 2113 dist._broadcast_coalesced( 2114 self.process_group, tensors, buffer_size, authoritative_rank 2115 ) 2116 2117 def _check_sync_bufs_post_fwd(self): 2118 return ( 2119 self.will_sync_module_buffers() 2120 and hasattr(self, "buffer_hook") 2121 and self.buffer_hook.buffer_comm_hook_location 2122 == _BufferCommHookLocation.POST_FORWARD 2123 ) 2124 2125 def _check_sync_bufs_pre_fwd(self): 2126 return self.will_sync_module_buffers() and ( 2127 not hasattr(self, "buffer_hook") 2128 or self.buffer_hook.buffer_comm_hook_location 2129 == _BufferCommHookLocation.PRE_FORWARD 2130 ) 2131 2132 def will_sync_module_buffers(self): 2133 return ( 2134 self.require_forward_param_sync 2135 and self.broadcast_buffers 2136 and len(self.modules_buffers) > 0 2137 ) 2138 2139 def _find_common_rank(self, input_rank, rank_cond): 2140 # -1 indicates that this rank is not under consideration to be the 2141 # common_rank 2142 rank_to_use = torch.tensor( 2143 [input_rank if rank_cond else -1], 2144 device=self.device, 2145 ) 2146 dist.all_reduce(rank_to_use, op=ReduceOp.MAX, group=self.process_group) 2147 if rank_to_use.item() == -1: 2148 self._log_and_throw( 2149 ValueError, 2150 "BUG! Expected rank_cond to be true for at least one process." 2151 " This indicates a bug in PyTorch, please report an issue.", 2152 ) 2153 return rank_to_use.item() 2154 2155 def _sync_buffers(self): 2156 with torch.no_grad(): 2157 # module buffer sync 2158 # Synchronize buffers across processes. 2159 # If we are running DDP with the join manager, we have to agree 2160 # upon a rank to sync module buffers from, since rank 0 may 2161 # already have been joined and have stale module buffers. 2162 if self._join_config.enable: 2163 authoritative_rank = self._find_common_rank( 2164 self._distributed_rank, True 2165 ) 2166 else: 2167 # The process with rank 0 is considered the authoritative copy. 2168 authoritative_rank = 0 2169 # Update self.modules_buffers incase any buffers were 2170 # reassigned. 2171 self._assign_modules_buffers() 2172 self._sync_module_buffers(authoritative_rank) 2173 2174 def _sync_module_buffers(self, authoritative_rank): 2175 if not hasattr(self, "buffer_hook"): 2176 self._default_broadcast_coalesced(authoritative_rank=authoritative_rank) 2177 else: 2178 hook = self.buffer_hook.buffer_comm_hook 2179 state = self.buffer_hook.buffer_comm_hook_state 2180 futs = hook(state, self.named_module_buffers) 2181 if futs is not None: 2182 self.reducer._install_post_backward_futures(futs) 2183 2184 def _default_broadcast_coalesced( 2185 self, bufs=None, bucket_size=None, authoritative_rank=0 2186 ): 2187 """ 2188 Broadcasts buffers from rank 0 to rest of workers. 2189 2190 If bufs, bucket_size are None, default values self.modules_buffers 2191 and self.broadcast_bucket_size are used instead. 2192 """ 2193 if bufs is None: 2194 bufs = self.modules_buffers 2195 if bucket_size is None: 2196 bucket_size = self.broadcast_bucket_size 2197 2198 self._distributed_broadcast_coalesced(bufs, bucket_size, authoritative_rank) 2199 2200 def _passing_sync_batchnorm_handle(self, module): 2201 for layer in module.modules(): 2202 if isinstance(layer, torch.nn.modules.SyncBatchNorm): 2203 if self.device_type == "cpu": 2204 self._log_and_throw( 2205 ValueError, 2206 "SyncBatchNorm layers only work with GPU modules", 2207 ) 2208 2209 def _check_comm_hook(self, hook): 2210 if not callable(hook): 2211 self._log_and_throw(TypeError, "Communication hook must be callable.") 2212 2213 sig = inspect.signature(hook) 2214 if ( 2215 sig.parameters["bucket"].annotation != inspect._empty 2216 and sig.parameters["bucket"].annotation != dist.GradBucket 2217 ): 2218 self._log_and_throw( 2219 ValueError, 2220 "Communication hook: bucket annotation should be dist.GradBucket.", 2221 ) 2222 2223 if ( 2224 sig.return_annotation != inspect._empty 2225 and sig.return_annotation != torch.futures.Future[torch.Tensor] 2226 ): 2227 self._log_and_throw( 2228 ValueError, 2229 "Communication hook: return annotation should be torch.futures.Future[torch.Tensor].", 2230 ) 2231 2232 if hook.__name__ in [ 2233 "bf16_compress_hook", 2234 "bf16_compress_wrapper_hook", 2235 ] and ( 2236 (torch.version.cuda is None and torch.version.hip is None) 2237 or ( 2238 torch.version.cuda is not None 2239 and int(torch.version.cuda.split(".")[0]) < 11 2240 ) 2241 or not dist.is_available() 2242 or not dist.is_nccl_available() 2243 or torch.cuda.nccl.version() < (2, 10) 2244 ): 2245 self._log_and_throw( 2246 TypeError, 2247 "BF16 all reduce communication hook required CUDA 11+ and NCCL 2.10+.", 2248 ) 2249 2250 @property 2251 def _distributed_rank(self): 2252 return dist.get_rank(self.process_group) 2253 2254 @staticmethod 2255 def _get_data_parallel_params(module, named_params=False): 2256 """Return a generator of parameters managed by a given DDP unit.""" 2257 for param in ( 2258 module.parameters() if not named_params else module.named_parameters() 2259 ): 2260 if not hasattr(param, "_ddp_ignored"): 2261 yield param 2262 2263 @staticmethod 2264 def _set_params_and_buffers_to_ignore_for_model( 2265 module, params_and_buffers_to_ignore 2266 ): 2267 """ 2268 Set parameters and buffers to be ignored by DDP. 2269 2270 Expected format for parameters is the fully qualified name: {module_name}.{param_name}, and 2271 similarly, {module_name}.{buffer_name} for buffers. For example: 2272 params_to_ignore = [] 2273 # NB: model here is vanilla PyTorch module, not yet wrapped with DDP. 2274 for module_name, module in model.named_modules(): 2275 for param_name, param in module.named_parameters(recurse=False): 2276 if should_ignore(param): 2277 # Create expected format 2278 fqn = f"{module_name}.{param_name}" 2279 params_to_ignore.append(fqn) 2280 torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model( 2281 model, 2282 params_to_ignore 2283 ) 2284 """ 2285 # This is a workaround to set parameters and buffers DDP should ignore 2286 # during synchronization. It will be removed when the API is finalized 2287 # as part of addressing https://github.com/pytorch/pytorch/issues/43690. 2288 module._ddp_params_and_buffers_to_ignore = params_and_buffers_to_ignore 2289 for name, param in module.named_parameters(): 2290 if name in params_and_buffers_to_ignore: 2291 param._ddp_ignored = True 2292 for name, buffer in module.named_buffers(): 2293 if name in params_and_buffers_to_ignore: 2294 buffer._ddp_ignored = True 2295 2296 def _get_ddp_logging_data(self): 2297 r""" 2298 Return a dictionary of logging data for debugging and analysis. 2299 2300 This interface can be called after DistributedDataParallel() is 2301 constructed. It returns a dictionary of logging data. It could help 2302 for debugging and analysis. The logging data includes DistributedDataParallel 2303 constructor input parameters, some internal states of DistributedDataParallel 2304 and performance metrics. Simply print the dictionary and see what 2305 these metrics are. 2306 This is a prototype interface and subject to change in the future. 2307 """ 2308 assert self.logger is not None 2309 ddp_logging_data = self.logger._get_ddp_logging_data() 2310 return {**ddp_logging_data.strs_map, **ddp_logging_data.ints_map} 2311 2312 def _set_ddp_runtime_logging_sample_rate(self, sample_rate): 2313 r""" 2314 Set sample_rate of collecting runtime stats. 2315 2316 This interface allows users to set sample_rate of collecting 2317 runtime stats. The runtime stats will be recorded for the 2318 first 10 iterations, after 10 iterations runtime stats will be 2319 recorded once every "sample_rate" training iterations. In 2320 default, runtime stats are recorded for the first 10 iterations, 2321 after 10 iterations runtime stats are recorded once every 2322 "kDDPRuntimeLoggingSampleRate=100" training iterations. 2323 This is a prototype interface and subject to change in the future. 2324 """ 2325 if sample_rate < 1: 2326 self._log_and_throw( 2327 ValueError, 2328 "DDP runtime logging sample rate should be equal or greater than 1", 2329 ) 2330 self.reducer._set_ddp_runtime_logging_sample_rate(sample_rate) 2331 2332 def _set_static_graph(self): 2333 """ 2334 Set static graph for DDP. 2335 2336 It is recommended to set static graph in the DDP constructor, which will 2337 call this private API internally. 2338 """ 2339 # If self.static_graph has been set, no need to set it again 2340 if self.static_graph: 2341 warnings.warn( 2342 "You've set static_graph to be True, no need to set it again." 2343 ) 2344 return 2345 self.static_graph = True 2346 self._static_graph_delay_allreduce_enqueued = False 2347 self.reducer._set_static_graph() 2348 assert self.logger is not None 2349 self.logger._set_static_graph() 2350 if self.find_unused_parameters: 2351 warnings.warn( 2352 "You passed find_unused_parameters=true to DistributedDataParallel, " 2353 "`_set_static_graph` will detect unused parameters automatically, so " 2354 "you do not need to set find_unused_parameters=true, just be sure these " 2355 "unused parameters will not change during training loop while calling " 2356 "`_set_static_graph`." 2357 ) 2358 2359 def _remove_autograd_hooks(self): 2360 """Remove autograd hooks registered by the reducer on the model parameters.""" 2361 self.reducer._remove_autograd_hooks() 2362 2363 def _check_reducer_finalized(self): 2364 """ 2365 Check if the reducer has processed all buckets and finalized the backward appropriately. 2366 2367 It is useful to call this method after calling .backward() in your training loop 2368 in order to avoid subsequent hard to debug errors down the road due to the 2369 reducer not finalizing backward. 2370 """ 2371 self.reducer._check_reducer_finalized() 2372 2373 def _set_sparse_metadata(self, global_unique_ids): 2374 self.reducer._set_sparse_metadata(global_unique_ids) 2375 2376 def _update_process_group(self, new_process_group): 2377 """ 2378 Dynamically updates the process group for DDP so that we can shrink/expand DDP 2379 world size without having to reinitialize DDP. 2380 2381 NOTE: If you are using custom communications hooks via, register_comm_hook, 2382 you need to update the process groups for those hooks separately. 2383 """ 2384 # Force a rebuild of buckets for a new process group. This ensures all ranks 2385 # are synchronized in terms of when they will rebuild buckets and also 2386 # re-evaluates previous assumptions of buckets given the world size might have 2387 # changed. 2388 self._has_rebuilt_buckets = False 2389 self.reducer._reset_state() 2390 2391 if not _rank_not_in_group(new_process_group): 2392 self.process_group = new_process_group 2393 self.reducer._update_process_group(new_process_group) 2394 2395 def _set_ddp_sink_clone(self, val: bool): 2396 """ 2397 Sets whether or not DDPSink should clone the output tensors or not. 2398 The default is True since if the loss is modified in place we run 2399 into the view is modified in-place error. 2400 2401 Although, cloning the tensors can add significant memory and 2402 performance hit if the number and size of tensors are large. As 2403 a result, this can be set to False if you are not modifying the 2404 loss in place. 2405 """ 2406 self._ddp_sink_clone = val 2407