xref: /aosp_15_r20/external/pytorch/torch/nn/parallel/distributed.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import copy
3import functools
4import inspect
5import itertools
6import logging
7import os
8import sys
9import warnings
10import weakref
11from collections import defaultdict, deque
12from contextlib import contextmanager
13from dataclasses import dataclass, fields, is_dataclass
14from enum import auto, Enum
15from typing import Any, Callable, List, Optional, Tuple, Type, TYPE_CHECKING
16
17import torch
18import torch.distributed as dist
19from torch._utils import _get_device_index
20from torch.autograd import Function, Variable
21from torch.distributed.algorithms.join import Join, Joinable, JoinHook
22from torch.nn.modules import Module
23from torch.nn.parallel.scatter_gather import gather, scatter_kwargs
24from torch.utils._pytree import tree_flatten, tree_unflatten
25
26
27RPC_AVAILABLE = False
28if dist.is_available():
29    from torch.distributed.distributed_c10d import (
30        _get_default_group,
31        _rank_not_in_group,
32        ReduceOp,
33    )
34    from torch.distributed.utils import (
35        _alloc_storage,
36        _cast_forward_inputs,
37        _free_storage,
38        _sync_module_states,
39        _to_kwargs,
40        _verify_param_shape_across_processes,
41    )
42if dist.rpc.is_available():
43    RPC_AVAILABLE = True
44    from torch.distributed.rpc import RRef
45
46if TYPE_CHECKING:
47    from torch.utils.hooks import RemovableHandle
48
49
50__all__ = ["DistributedDataParallel"]
51
52logger = logging.getLogger(__name__)
53
54
55@dataclass
56class _MixedPrecision:
57    """
58    This configures DDP-native mixed precision training.
59
60    Attributes:
61        param_dtype (torch.dtype): This specifies the dtype for model
62            parameters, inputs (when ``cast_forward_inputs`` is set to
63            ``True``), and therefore the dtype for computation.
64            However, outside the forward and backward passes, parameters are in
65            full precision. Model checkpointing always happens in full
66            precision.
67        reduce_dtype (torch.dtype): This specifies the dtype for gradient
68            reduction, which is permitted to differ from ``param_dtype``.
69        buffer_dtype (torch.dtype): This specifies the dtype for buffers.
70
71    .. note:: This API is experimental and subject to change.
72
73    .. note:: Only floating point tensors are cast to their specified dtypes.
74
75    .. note:: ``state_dict`` checkpoints parameters and buffers in full
76        precision.
77
78    .. note:: Each low precision dtype must be specified explicitly. For
79        example, ``_MixedPrecision(reduce_dtype=torch.float16)`` only specifies
80        the reduction dtype to be low precision, and DDP will not cast
81        parameters or buffers.
82
83    .. note:: If a ``reduce_dtype`` is not specified, then gradient reduction
84        happens in ``param_dtype`` if specified or the original parameter dtype
85        otherwise. For example, ``_MixedPrecision(param_dtype=torch.float16)``
86        would result in communication occurring in fp16.
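
    Example (a minimal sketch; ``model`` and ``rank`` are placeholders, and this
    private API is experimental and may change)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> mp_config = _MixedPrecision(
        >>>     param_dtype=torch.float16,
        >>>     reduce_dtype=torch.float16,
        >>>     buffer_dtype=torch.float16,
        >>> )
        >>> # DDP accepts this config through its ``mixed_precision`` argument.
        >>> ddp = torch.nn.parallel.DistributedDataParallel(
        >>>     model, device_ids=[rank], mixed_precision=mp_config
        >>> )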
87    """
88
89    param_dtype: Optional[torch.dtype] = None
90    reduce_dtype: Optional[torch.dtype] = None
91    buffer_dtype: Optional[torch.dtype] = None
92    # TODO (rohan-varma): keep_low_precision_grads: bool = False
93    # TODO (rohan-varma): APIs to allow users to run batchnorm and layernorm
94    # in full precision. For DDP, this can be implemented by not performing the
95    # parameter cast for BN and LN units.
96
97
98def _cast_buffers(mixed_precision_config, root_module):
99    """Casts buffers to the given ``buffer_dtype``."""
100    for buf in root_module.buffers():
101        if hasattr(buf, "_ddp_ignored") and buf._ddp_ignored:
102            continue
103
104        buf.data = buf.to(dtype=mixed_precision_config.buffer_dtype)
105
106
107def _setup_mixed_precision_params(mixed_precision_config, root_module):
108    """Create and free storage for the mixed precision parameters."""
109    for param in root_module.parameters():
110        # Do not setup mixed precision for DDP ignored parameters.
111        if hasattr(param, "_ddp_ignored") and param._ddp_ignored:
112            continue
113
114        if not hasattr(param, "_mp_param"):
115            param._mp_param = torch.zeros_like(
116                param,
117                device=param.device,
118                dtype=mixed_precision_config.param_dtype,
119                requires_grad=param.requires_grad,
120            )
121            _free_storage(param._mp_param)
122            # _fp_param will point to the full precision param so it can be switched
123            # back to at the end of forward / backward.
124            param._fp_param = param.data
125
126
127def _tree_flatten_with_rref(output):
128    output_is_rref = RPC_AVAILABLE and isinstance(output, RRef)
129    if output_is_rref:
130        output_tensor_list, treespec = tree_flatten(output.local_value())
131    else:
132        output_tensor_list, treespec = tree_flatten(output)
133    # Return the flattened tensors, the spec needed to re-pack them, and
134    # whether the return type was actually an RRef, so it can be reconstructed.
135    return output_tensor_list, treespec, output_is_rref
136
137
138def _tree_unflatten_with_rref(output, treespec, output_is_rref):
139    output = tree_unflatten(output, treespec)
140    if output_is_rref:
141        output = RRef(output)
142    return output
143
144
145def _find_tensors(obj):
146    r"""Recursively find all tensors contained in the specified object."""
147    if RPC_AVAILABLE and isinstance(obj, RRef):
148        # If the current node is the owner of the RRef, unwrap it and try to
149        # find Tensors.
150        # TODO: Expand to remote RRefs.
151        if obj.is_owner():
152            return _find_tensors(obj.local_value())
153    if isinstance(obj, torch.Tensor):
154        return [obj]
155    if isinstance(obj, (list, tuple)):
156        return itertools.chain.from_iterable(map(_find_tensors, obj))
157    if isinstance(obj, dict):
158        return itertools.chain.from_iterable(map(_find_tensors, obj.values()))
159    if is_dataclass(obj):
160        return itertools.chain.from_iterable(
161            map(_find_tensors, (getattr(obj, f.name) for f in fields(obj)))
162        )
163
164    return []
165
166
167def _dump_DDP_relevant_env_vars():
168    relevant_env_vars = [
169        "RANK",
170        "LOCAL_RANK",
171        "WORLD_SIZE",
172        "MASTER_PORT",
173        "MASTER_ADDR",
174        "CUDA_VISIBLE_DEVICES",
175        "GLOO_SOCKET_IFNAME",
176        "GLOO_DEVICE_TRANSPORT",
177        "NCCL_SOCKET_IFNAME",
178        "TORCH_NCCL_BLOCKING_WAIT",
179        "NCCL_DEBUG",
180        "NCCL_DEBUG_SUBSYS",
181        "NCCL_IB_DISABLE",
182        # More NCCL env vars:
183        "NCCL_P2P_DISABLE",
184        "NCCL_P2P_LEVEL",
185        "NCCL_SHM_DISABLE",
186        "NCCL_SOCKET_NTHREADS",
187        "NCCL_NSOCKS_PERTHREAD",
188        "NCCL_BUFFSIZE",
189        "NCCL_NTHREADS",
190        "NCCL_RINGS",
191        "NCCL_MAX_NCHANNELS",
192        "NCCL_MIN_NCHANNELS",
193        "NCCL_CHECKS_DISABLE",
194        "NCCL_CHECK_POINTERS",
195        "NCCL_LAUNCH_MODE",
196        "NCCL_IB_HCA",
197        "NCCL_IB_TIMEOUT",
198        "NCCL_IB_RETRY_CNT",
199        "NCCL_IB_GID_INDEX",
200        "NCCL_IB_SL",
201        "NCCL_IB_TC",
202        "NCCL_IB_AR_THRESHOLD",
203        "NCCL_IB_CUDA_SUPPORT",
204        "NCCL_NET_GDR_LEVEL",
205        "NCCL_NET_GDR_READ",
206        "NCCL_SINGLE_RING_THRESHOLD",
207        "NCCL_LL_THRESHOLD",
208        "NCCL_TREE_THRESHOLD",
209        "NCCL_ALGO",
210        "NCCL_PROTO",
211        "NCCL_IGNORE_CPU_AFFINITY",
212        "NCCL_DEBUG_FILE",
213        "NCCL_COLLNET_ENABLE",
214        "NCCL_TOPO_FILE",
215        "NCCL_TOPO_DUMP_FILE",
216        "TORCH_NCCL_ASYNC_ERROR_HANDLING",
217    ]
218    formatted_output = ""
219    for var in relevant_env_vars:
220        value = os.environ[var] if var in os.environ else "N/A"
221        formatted_output += f"env:{var}={value}\n"
222    print(formatted_output)
223
224
225class _BufferCommHookLocation(Enum):
226    PRE_FORWARD = auto()
227    POST_FORWARD = auto()
228
229
230@dataclass
231class _BufferCommHook:
232    buffer_comm_hook: Callable
233    buffer_comm_hook_state: Any
234    buffer_comm_hook_location: _BufferCommHookLocation
235
236
237# Add a DDPSink to run various functions when backward starts, such as
238# queueing the callback of the outermost backward/graph task. This helps
239# ensure the callback is fired only after all gradient calculations have
240# completed.
241class _DDPSink(Function):
242    @staticmethod
243    def forward(ctx, ddp_weakref, *inputs):
244        # set_materialize_grads(False) will ensure that None gradients stay as
245        # None and are not filled with zeros.
246        ctx.set_materialize_grads(False)
247        ctx.ddp_weakref = ddp_weakref
248        ret = inputs
249        if ddp_weakref()._ddp_sink_clone:
250            ret = tuple(
251                inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs
252            )
253        return ret
254
255    @staticmethod
256    def backward(ctx, *grad_outputs):
257        # Enqueue delay allreduce for static graph training on the first
258        # iteration.
259        ddp_weakref = ctx.ddp_weakref()
260        reducer = ddp_weakref.reducer
261        static_graph = ddp_weakref.static_graph
262        delay_ar_enqueued = (
263            static_graph and ddp_weakref._static_graph_delay_allreduce_enqueued
264        )
265        if static_graph and not delay_ar_enqueued:
266            Variable._execution_engine.queue_callback(  # type: ignore[call-arg,misc]
267                reducer._delay_all_reduce
268            )
269            ddp_weakref._static_graph_delay_allreduce_enqueued = True
270
271        return (None, *grad_outputs)
272
273
274class _DDPJoinHook(JoinHook):
275    def __init__(self, ddp, divide_by_initial_world_size):
276        """Set config variables for internal usage."""
277        assert isinstance(ddp, DistributedDataParallel), (
278            "DDP join hook requires passing in a DistributedDataParallel "
279            "instance as the state"
280        )
281        assert ddp.logger is not None
282        ddp.logger._set_uneven_input_join()
283        self.ddp = ddp
284        self.ddp._divide_by_initial_world_size = divide_by_initial_world_size
285        super().__init__()
286
287    def main_hook(self):
288        """Shadow the DDP collective communication operations in the forward and backward passes."""
289        ddp = self.ddp
290        # Buckets are rebuilt only once during a training period
291        ddp.reducer._rebuild_buckets()
292
293        # Schedule a broadcast if we are syncing module buffers in the
294        # forward pass
295        # TODO: make DDP uneven inputs context manager support buffer
296        # comm hook (https://github.com/pytorch/pytorch/issues/65436)
297        ddp._check_and_sync_module_buffers()
298
299        # Check if need to sync in the backward pass
300        should_sync_backwards = ddp._check_global_requires_backward_grad_sync(
301            is_joined_rank=True
302        )
303        # Forward parameter sync is disabled in the next iteration if we
304        # are skipping gradient sync this iteration, so set
305        # `require_forward_param_sync` accordingly
306        ddp.require_forward_param_sync = should_sync_backwards
307        if not should_sync_backwards:
308            return
309
310        # Schedule one allreduce per gradient bucket to match the backward
311        # pass allreduce
312        ddp._match_all_reduce_for_bwd_pass()
313
314        # Check if we need to allreduce locally unused parameters
315        if ddp.find_unused_parameters:
316            ddp._match_unused_params_allreduce()
317
318        # Rebuilt parameters are pushed only once during a training period
319        ddp.reducer._push_all_rebuilt_params()
320
321    def post_hook(self, is_last_joiner: bool):
322        """Sync the final model to ensure that the model is the same across all processes."""
323        self.ddp._sync_final_model(is_last_joiner)
324
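# For reference, _DDPJoinHook is what backs ``DistributedDataParallel.join()``
# (and the generic ``Join`` context manager) for training with uneven inputs.
# A minimal usage sketch (``ddp`` and the per-rank ``inputs`` are placeholders):
#
#   with ddp.join():
#       for batch in inputs:
#           loss = ddp(batch).sum()
#           loss.backward()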
325
326class DistributedDataParallel(Module, Joinable):
327    r"""Implement distributed data parallelism based on ``torch.distributed`` at module level.
328
329    This container provides data parallelism by synchronizing gradients
330    across each model replica. The devices to synchronize across are
331    specified by the input ``process_group``, which is the entire world
332    by default. Note that ``DistributedDataParallel`` does not chunk or
333    otherwise shard the input across participating GPUs; the user is
334    responsible for defining how to do so, for example through the use
335    of a :class:`DistributedSampler`.
336
337    See also: :ref:`distributed-basics` and :ref:`cuda-nn-ddp-instead`.
338    The same constraints on input as in :class:`torch.nn.DataParallel` apply.
339
340    Creation of this class requires that ``torch.distributed`` be already
341    initialized, by calling :func:`torch.distributed.init_process_group`.
342
343    ``DistributedDataParallel`` is proven to be significantly faster than
344    :class:`torch.nn.DataParallel` for single-node multi-GPU data
345    parallel training.
346
347    To use ``DistributedDataParallel`` on a host with N GPUs, you should spawn
348    up ``N`` processes, ensuring that each process exclusively works on a single
349    GPU from 0 to N-1. This can be done by either setting
350    ``CUDA_VISIBLE_DEVICES`` for every process or by calling:
351
352        >>> # xdoctest: +SKIP("undefined variables")
353        >>> torch.cuda.set_device(i)
354
355    where i is from 0 to N-1. In each process, you should refer to the following
356    to construct this module:
357
358        >>> # xdoctest: +SKIP("undefined variables")
359        >>> torch.distributed.init_process_group(
360        >>>     backend='nccl', world_size=N, init_method='...'
361        >>> )
362        >>> model = DistributedDataParallel(model, device_ids=[i], output_device=i)
363
364    In order to spawn up multiple processes per node, you can use either
365    ``torch.distributed.launch`` or ``torch.multiprocessing.spawn``.
366
367    .. note::
368        Please refer to `PyTorch Distributed Overview <https://pytorch.org/tutorials/beginner/dist_overview.html>`__
369        for a brief introduction to all features related to distributed training.
370
371    .. note::
372        ``DistributedDataParallel`` can be used in conjunction with
373        :class:`torch.distributed.optim.ZeroRedundancyOptimizer` to reduce
374        per-rank optimizer states memory footprint. Please refer to
375        `ZeroRedundancyOptimizer recipe <https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html>`__
376        for more details.
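
        A minimal sketch (``ddp_model`` is a placeholder for an already-wrapped
        model)::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> from torch.distributed.optim import ZeroRedundancyOptimizer
            >>> optimizer = ZeroRedundancyOptimizer(
            >>>     ddp_model.parameters(),
            >>>     optimizer_class=torch.optim.SGD,
            >>>     lr=0.01,
            >>> )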
377
378    .. note:: ``nccl`` backend is currently the fastest and highly recommended
379        backend when using GPUs. This applies to both single-node and
380        multi-node distributed training.
381
382    .. note:: This module also supports mixed-precision distributed training.
383        This means that your model can have parameters of different types, such
384        as a mix of ``fp16`` and ``fp32``, and gradient reduction on these
385        mixed-type parameters will work correctly.
386
387    .. note:: If you use ``torch.save`` on one process to checkpoint the module,
388        and ``torch.load`` on some other processes to recover it, make sure that
389        ``map_location`` is configured properly for every process. Without
390        ``map_location``, ``torch.load`` will load the module onto the devices
391        where the module was saved.
392
393    .. note:: When a model is trained on ``M`` nodes with ``batch=N``, the
394        gradient will be ``M`` times smaller when compared to the same model
395        trained on a single node with ``batch=M*N`` if the loss is summed (NOT
396        averaged as usual) across instances in a batch (because the gradients
397        between different nodes are averaged). You should take this into
398        consideration when you want to obtain a mathematically equivalent
399        training process compared to the local training counterpart. But in most
400        cases, you can just treat a DistributedDataParallel wrapped model, a
401        DataParallel wrapped model and an ordinary model on a single GPU as the
402        same (E.g. using the same learning rate for equivalent batch size).
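
        As a concrete illustration: with a summed loss, a single node with
        ``batch=M*N`` produces a gradient that sums ``M*N`` per-sample
        gradients, while each of the ``M`` DDP workers with ``batch=N`` sums
        only ``N`` of them, and the allreduce *averages* the workers' gradients,
        leaving a result that is ``M`` times smaller. Scaling the learning rate
        by ``M`` (or averaging the loss) restores the equivalence.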
403
404    .. note::
405        Parameters are never broadcast between processes. The module performs
406        an all-reduce step on gradients and assumes that they will be modified
407        by the optimizer in all processes in the same way. Buffers
408        (e.g. BatchNorm stats) are broadcast from the module in the process of
409        rank 0 to all other replicas in the system in every iteration.
410
411    .. note::
412        If you are using DistributedDataParallel in conjunction with the
413        :ref:`distributed-rpc-framework`, you should always use
414        :meth:`torch.distributed.autograd.backward` to compute gradients and
415        :class:`torch.distributed.optim.DistributedOptimizer` for optimizing
416        parameters.
417
418        Example::
419
420            >>> # xdoctest: +SKIP("undefined variables")
421            >>> import torch.distributed.autograd as dist_autograd
422            >>> from torch.nn.parallel import DistributedDataParallel as DDP
423            >>> import torch
424            >>> from torch import optim
425            >>> from torch.distributed.optim import DistributedOptimizer
426            >>> import torch.distributed.rpc as rpc
427            >>> from torch.distributed.rpc import RRef
428            >>>
429            >>> t1 = torch.rand((3, 3), requires_grad=True)
430            >>> t2 = torch.rand((3, 3), requires_grad=True)
431            >>> rref = rpc.remote("worker1", torch.add, args=(t1, t2))
432            >>> ddp_model = DDP(my_model)
433            >>>
434            >>> # Setup optimizer
435            >>> optimizer_params = [rref]
436            >>> for param in ddp_model.parameters():
437            >>>     optimizer_params.append(RRef(param))
438            >>>
439            >>> dist_optim = DistributedOptimizer(
440            >>>     optim.SGD,
441            >>>     optimizer_params,
442            >>>     lr=0.05,
443            >>> )
444            >>>
445            >>> with dist_autograd.context() as context_id:
446            >>>     pred = ddp_model(rref.to_here())
447            >>>     loss = loss_func(pred, target)
448            >>>     dist_autograd.backward(context_id, [loss])
449            >>>     dist_optim.step(context_id)
450
451    .. note::
452        DistributedDataParallel currently offers limited support for gradient
453        checkpointing with :meth:`torch.utils.checkpoint`.
454        If the checkpoint is done with use_reentrant=False (recommended), DDP
455        will work as expected without any limitations.
456        If, however, the checkpoint is done with use_reentrant=True (the default),
457        DDP will work as expected when there are no unused parameters in the model
458        and each layer is checkpointed at most once (make sure you are not passing
459        `find_unused_parameters=True` to DDP). We currently do not support the
460        case where a layer is checkpointed multiple times, or when there are unused
461        parameters in the checkpointed model.
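
        For illustration, a minimal sketch of the recommended non-reentrant
        checkpointing under DDP (``rank`` is a placeholder for the local
        device index)::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> from torch.utils.checkpoint import checkpoint
            >>> class Net(torch.nn.Module):
            >>>     def __init__(self):
            >>>         super().__init__()
            >>>         self.block = torch.nn.Linear(10, 10)
            >>>     def forward(self, x):
            >>>         # use_reentrant=False works with DDP without limitations
            >>>         return checkpoint(self.block, x, use_reentrant=False)
            >>> ddp = torch.nn.parallel.DistributedDataParallel(
            >>>     Net().to(rank), device_ids=[rank]
            >>> )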
462
463    .. note::
464        To let a non-DDP model load a state dict from a DDP model,
465        :meth:`~torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`
466        needs to be applied to strip the prefix "module." in the DDP state dict before loading.
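
        For example (a minimal sketch; ``ddp_model`` and ``local_model`` are
        placeholders)::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> from torch.nn.modules.utils import (
            >>>     consume_prefix_in_state_dict_if_present,
            >>> )
            >>> state_dict = ddp_model.state_dict()
            >>> # strips the "module." prefix in place so a plain module can load it
            >>> consume_prefix_in_state_dict_if_present(state_dict, "module.")
            >>> local_model.load_state_dict(state_dict)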
467
468    .. warning::
469        Constructor, forward method, and differentiation of the output (or a
470        function of the output of this module) are distributed synchronization
471        points. Take that into account in case different processes might be
472        executing different code.
473
474    .. warning::
475        This module assumes all parameters are registered in the model by the
476        time it is created. No parameters should be added nor removed later.
477        Same applies to buffers.
478
479    .. warning::
480        This module assumes that all parameters are registered in the model of
481        each distributed process in the same order. The module itself will
482        conduct gradient ``allreduce`` following the reverse order of the
483        registered parameters of the model. In other words, it is the users'
484        responsibility to ensure that each distributed process has the exact
485        same model and thus the exact same parameter registration order.
486
487    .. warning::
488        This module allows parameters with non-rowmajor-contiguous strides.
489        For example, your model may contain some parameters whose
490        :class:`torch.memory_format` is ``torch.contiguous_format``
491        and others whose format is ``torch.channels_last``.  However,
492        corresponding parameters in different processes must have the
493        same strides.
494
495    .. warning::
496        This module doesn't work with :func:`torch.autograd.grad` (i.e. it will
497        only work if gradients are to be accumulated in ``.grad`` attributes of
498        parameters).
499
500    .. warning::
501        If you plan on using this module with a ``nccl`` backend or a ``gloo``
502        backend (that uses Infiniband), together with a DataLoader that uses
503        multiple workers, please change the multiprocessing start method to
504        ``forkserver`` or ``spawn``. Unfortunately
505        Gloo (that uses Infiniband) and NCCL are not fork-safe, and you will
506        likely experience deadlocks if you don't change this setting.
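
        A minimal sketch of one way to do this (``dataset`` is a placeholder)::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> # either switch the global start method ...
            >>> torch.multiprocessing.set_start_method("forkserver", force=True)
            >>> # ... or request it only for the DataLoader workers
            >>> loader = torch.utils.data.DataLoader(
            >>>     dataset, num_workers=4, multiprocessing_context="spawn"
            >>> )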
507
508    .. warning::
509        You should never try to change your model's parameters after wrapping
510        your model with ``DistributedDataParallel``, because when you wrap
511        your model with ``DistributedDataParallel``, the constructor of
512        ``DistributedDataParallel`` will register the additional gradient
513        reduction functions on all the parameters of the model itself at
514        construction time. If you change the model's parameters afterwards,
515        the gradient reduction functions will no longer match the correct set
516        of parameters.
517
518    .. warning::
519        Using ``DistributedDataParallel`` in conjunction with the
520        :ref:`distributed-rpc-framework` is experimental and subject to change.
521
522    Args:
523        module (Module): module to be parallelized
524        device_ids (list of int or torch.device): CUDA devices.
525                   1) For single-device modules, ``device_ids`` can
526                   contain exactly one device id, which represents the only
527                   CUDA device where the input module corresponding to this process resides.
528                   Alternatively, ``device_ids`` can also be ``None``.
529                   2) For multi-device modules and CPU modules,
530                   ``device_ids`` must be ``None``.
531
532                   When ``device_ids`` is ``None`` for both cases,
533                   both the input data for the forward pass and the actual module
534                   must be placed on the correct device.
535                   (default: ``None``)
536        output_device (int or torch.device): Device location of output for
537                      single-device CUDA modules. For multi-device modules and
538                      CPU modules, it must be ``None``, and the module itself
539                      dictates the output location. (default: ``device_ids[0]``
540                      for single-device modules)
541        broadcast_buffers (bool): Flag that enables syncing (broadcasting)
542                          buffers of the module at beginning of the ``forward``
543                          function. (default: ``True``)
544        process_group: The process group to be used for distributed data
545                       all-reduction. If ``None``, the default process group, which
546                       is created by :func:`torch.distributed.init_process_group`,
547                       will be used. (default: ``None``)
548        bucket_cap_mb: ``DistributedDataParallel`` will bucket parameters into
549                       multiple buckets so that gradient reduction of each
550                       bucket can potentially overlap with backward computation.
551                       :attr:`bucket_cap_mb` controls the bucket size in
552                       MebiBytes (MiB). If ``None``, a default size of 25 MiB
553                       will be used. (default: ``None``)
554        find_unused_parameters (bool): Traverse the autograd graph from all
555                               tensors contained in the return value of the
556                               wrapped module's ``forward`` function. Parameters
557                               that don't receive gradients as part of this
558                               graph are preemptively marked as being ready to
559                               be reduced. In addition, parameters that may have
560                               been used in the wrapped module's ``forward``
561                               function but were not part of loss computation and
562                               thus would also not receive gradients are
563                               preemptively marked as ready to be reduced.
564                               (default: ``False``)
565        check_reduction: This argument is deprecated.
566        gradient_as_bucket_view (bool): When set to ``True``, gradients will be views
567                      pointing to different offsets of ``allreduce`` communication
568                      buckets. This can reduce peak memory usage, where the
569                      saved memory size will be equal to the total gradients
570                      size. Moreover, it avoids the overhead of copying between
571                      gradients and ``allreduce`` communication buckets. When
572                      gradients are views, ``detach_()`` cannot be called on the
573                      gradients. If you hit such errors, please fix it by
574                      referring to the :meth:`~torch.optim.Optimizer.zero_grad`
575                      function in ``torch/optim/optimizer.py`` as a solution.
576                      Note that gradients will be views after the first iteration,
577                      so the peak memory saving should be checked after the first iteration.
578        static_graph (bool): When set to ``True``, DDP knows the trained graph is
579                     static. Static graph means 1) The set of used and unused
580                     parameters will not change during the whole training loop; in
581                     this case, it does not matter whether users set
582                     ``find_unused_parameters = True`` or not. 2) How the graph is trained
583                     will not change during the whole training loop (meaning there is
584                     no control flow depending on iterations).
585                     When static_graph is set to ``True``, DDP will support cases that
586                     could not be supported in the past:
587                     1) Reentrant backwards.
588                     2) Activation checkpointing multiple times.
589                     3) Activation checkpointing when the model has unused parameters.
590                     4) Model parameters that are outside of the forward function.
591                     5) Potentially improved performance when there are unused parameters,
592                     as DDP will not search the graph in each iteration to detect unused
593                     parameters when static_graph is set to ``True``.
594                     One way to check whether you can set static_graph to ``True`` is to
595                     check the ddp logging data at the end of your previous model training;
596                     if ``ddp_logging_data.get("can_set_static_graph") == True``, you can
597                     most likely set ``static_graph = True`` as well.
598
599                     Example::
600                         >>> # xdoctest: +SKIP("undefined variables")
601                         >>> model_DDP = torch.nn.parallel.DistributedDataParallel(model)
602                         >>> # Training loop
603                         >>> ...
604                         >>> ddp_logging_data = model_DDP._get_ddp_logging_data()
605                         >>> static_graph = ddp_logging_data.get("can_set_static_graph")
606        delay_all_reduce_named_params (list of tuple of str and torch.nn.Parameter): a list
607                    of named parameters whose all-reduce will be delayed until the gradient of
608                    the parameter specified in ``param_to_hook_all_reduce`` is ready. The other
609                    DDP arguments do not apply to the named params specified in this argument,
610                    as these named params will be ignored by the DDP reducer.
611        param_to_hook_all_reduce (torch.nn.Parameter): a parameter to hook delayed all reduce
612                    of parameters specified in ``delay_all_reduce_named_params``.
613
614
615    Attributes:
616        module (Module): the module to be parallelized.
617
618    Example::
619
620        >>> # xdoctest: +SKIP("undefined variables")
621        >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
622        >>> net = torch.nn.parallel.DistributedDataParallel(model)
623    """
624
625    # used to track whether the given thread is inside ddp forward for torchdynamo purposes
626    _active_ddp_module: Optional["DistributedDataParallel"] = None
627
628    def __init__(
629        self,
630        module,
631        device_ids=None,
632        output_device=None,
633        dim=0,
634        broadcast_buffers=True,
635        process_group=None,
636        bucket_cap_mb=None,
637        find_unused_parameters=False,
638        check_reduction=False,
639        gradient_as_bucket_view=False,
640        static_graph=False,
641        delay_all_reduce_named_params=None,
642        param_to_hook_all_reduce=None,
643        mixed_precision: Optional[_MixedPrecision] = None,
644        device_mesh=None,
645    ):
646        super().__init__()
647        Joinable.__init__(self)
648        self.logger: Optional[dist.Logger] = None
649        if bool(delay_all_reduce_named_params is not None) != bool(
650            param_to_hook_all_reduce is not None
651        ):
652            self._log_and_throw(
653                ValueError,
654                "delay_all_reduce_named_params and param_to_hook_all_reduce "
655                "need to be set at the same time.",
656            )
657
658        if process_group and device_mesh is not None:
659            raise RuntimeError(
660                "Cannot specify both process_group and device_mesh arguments."
661            )
662        elif process_group is None and device_mesh is None:
663            self.process_group = _get_default_group()
664        elif device_mesh is None:
665            self.process_group = process_group
666        else:
667            if device_mesh.ndim != 1:
668                raise RuntimeError(
669                    f"Only 1D device mesh is supported, but got {device_mesh}."
670                )
671            self.device_mesh = device_mesh
672            self.process_group = device_mesh.get_group(mesh_dim=0)
673            from torch.distributed.device_mesh import _mesh_resources
674
675            root_mesh = _mesh_resources.get_root_mesh(device_mesh)
676            # If the root mesh is not the same as device_mesh, then the
677            # device_mesh was sliced out from the root mesh.
678            if root_mesh != device_mesh:
679                # TODO: This is a temporary workaround to enable DDP + TP.
680                # We should do the logic in DDP so that the 2D implementation is
681                # sound and the state_dict works out of the box.
682                # This has to be done before the UninitializedParameter check.
683                from torch.distributed.tensor.parallel.ddp import (
684                    _pre_dp_module_transform,
685                )
686
687                _pre_dp_module_transform(module)
688
689        self._delay_all_reduce_params = []
690        if hasattr(module, "_ddp_params_and_buffers_to_ignore"):
691            self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore)
692        else:
693            self.parameters_to_ignore = set()
694        if delay_all_reduce_named_params is not None:
695            for name, param in delay_all_reduce_named_params:
696                self.parameters_to_ignore.add(name)
697                self._delay_all_reduce_params.append(param)
698
699        self._module_parameters = [
700            p
701            for n, p in module.named_parameters()
702            if n not in self.parameters_to_ignore
703        ]
704        if not any(p.requires_grad for p in self._module_parameters):
705            if len(self._delay_all_reduce_params):
706                logger.info("Delay the AllReduce of all parameters.")
707            else:
708                self._log_and_throw(
709                    RuntimeError,
710                    "DistributedDataParallel is not needed when a module "
711                    "doesn't have any parameter that requires a gradient.",
712                )
713
714        if device_ids is not None and len(device_ids) > 1:
715            self._log_and_throw(
716                ValueError,
717                "device_ids can only be None or contain a single element.",
718            )
719
720        self.is_multi_device_module = (
721            len({p.device for p in self._module_parameters}) > 1
722        )
723        distinct_device_types = {
724            p.device.type for p in self._module_parameters if p.device is not None
725        }
726        if len(distinct_device_types) != 1:
727            self._log_and_throw(
728                ValueError,
729                "DistributedDataParallel's input module must be on "
730                f"the same type of devices, but input module parameters locate in {distinct_device_types}.",
731            )
732
733        self.device_type = next(iter(distinct_device_types))
734
735        if (
736            device_ids is None
737            or len(device_ids) == 0  # For backward compatibility.
738            or self.device_type == "cpu"
739            or self.is_multi_device_module
740        ):
741            if device_ids or output_device:
742                self._log_and_throw(
743                    ValueError,
744                    "DistributedDataParallel device_ids and output_device arguments "
745                    "only work with single-device/multiple-device GPU modules or CPU modules, "
746                    f"but got device_ids {device_ids}, output_device {output_device}, "
747                    f"and module parameters {({p.device for p in self._module_parameters})}.",
748                )
749
750            self.device_ids = None
751            self.output_device = None
752        else:
753            self.device_ids = [_get_device_index(x, True) for x in device_ids]
754
755            if output_device is None:
756                output_device = device_ids[0]
757
758            self.output_device = _get_device_index(output_device, True)
759
760        self.static_graph = False
761        self.dim = dim
762        self.module = module
763        self.device = next(iter(self._module_parameters)).device
764        self.broadcast_buffers = broadcast_buffers
765        self.find_unused_parameters = find_unused_parameters
766        self.require_backward_grad_sync = True
767        self.require_forward_param_sync = True
768        self.gradient_as_bucket_view = gradient_as_bucket_view
769        self.mixed_precision = mixed_precision
770        if self.mixed_precision is not None:
771            logger.warning("Received mixed precision config %s", self.mixed_precision)
772
773        if check_reduction:
774            # This argument is no longer used since the reducer
775            # will ensure reduction completes even if some parameters
776            # do not receive gradients.
777            warnings.warn(
778                "The `check_reduction` argument in `DistributedDataParallel` "
779                "module is deprecated. Please avoid using it.",
780                FutureWarning,
781                stacklevel=2,
782            )
783
784        # Check that a module does not have Uninitialized parameters
785        for param in self._module_parameters:
786            if isinstance(param, torch.nn.parameter.UninitializedParameter):
787                self._log_and_throw(
788                    RuntimeError,
789                    "Modules with uninitialized parameters can't be used with `DistributedDataParallel`. "
790                    "Run a dummy forward pass to correctly initialize the modules",
791                )
792        # used for intra-node param sync and inter-node sync as well
793        self.broadcast_bucket_size = int(250 * 1024 * 1024)
794
795        # reduction bucket size
796        if bucket_cap_mb is None:
797            # default case (bucket cap is 25 MiB)
798            bucket_cap_mb = 25
799            self.bucket_bytes_cap_default = True
800        else:
801            self.bucket_bytes_cap_default = False
802        self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
803
804        # Whether to perform input tensor CPU to GPU copies on a side-stream
805        self.use_side_stream_for_tensor_copies = (
806            os.environ.get("PYTORCH_DDP_USE_SIDE_STREAM", "1") == "1"
807        )
808
809        # Initialize gradient buffers and register all reduce hook
810        self._delay_grad_buffer: Optional[torch.Tensor] = None
811        self._delay_grad_views: List[torch.Tensor] = []
812        self._delay_all_reduce_all_params = False
813        if len(self._delay_all_reduce_params) != 0:
814            self._register_delay_all_reduce_hook(
815                bucket_cap_mb=bucket_cap_mb,
816                param_to_hook_all_reduce=param_to_hook_all_reduce,
817                device_ids=device_ids,
818            )
819            if self._delay_all_reduce_all_params:
820                return
821
822        # Build parameters for reducer.
823        parameters, expect_sparse_gradient = self._build_params_for_reducer()
824        # Verify model equivalence.
825        _verify_param_shape_across_processes(self.process_group, parameters)
826        # Sync params and buffers. Ensures all DDP models start off at the same value.
827        _sync_module_states(
828            module=self.module,
829            process_group=self.process_group,
830            broadcast_bucket_size=self.broadcast_bucket_size,
831            src=0,
832            params_and_buffers_to_ignore=self.parameters_to_ignore,
833            broadcast_buffers=self.broadcast_buffers,
834        )
835        # In debug mode, build a mapping of parameter index -> parameter.
836        param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters)
837
838        # Builds reducer.
839        self._ddp_init_helper(
840            parameters,
841            expect_sparse_gradient,
842            param_to_name_mapping,
843            static_graph,
844        )
845        self._comm_hooks: List[Tuple[Callable, object]] = []
846
847        if self.mixed_precision is not None:
848            _setup_mixed_precision_params(self.mixed_precision, self.module)
849            _cast_buffers(self.mixed_precision, self.module)
850            # Stream used for async low precision copies.
851            self._mp_stream = torch.cuda.Stream()
852            self._submodule_to_event = defaultdict(deque)  # type: ignore[var-annotated]
853            # Add forward pre-hook to root module to kick off copies to lower
854            # precision.
855            self.module.register_forward_pre_hook(
856                self._root_copy_hook, prepend=False, with_kwargs=True
857            )
858            # Add forward pre hook to all submodules to wait for copy events
859            # before running computation.
860            for module in self.module.modules():
861                module.register_forward_pre_hook(
862                    self._module_wait_for_copy_hook,
863                    prepend=False,
864                    with_kwargs=True,
865                )
866            # Set up callbacks in backward to upcast and use full precision
867            # params. TODO (rohan-varma): Make this compose with general
868            # comm hooks and apply_optimizer_in_backward. Importing inline to
869            # avoid circular import issue.
870            from torch.distributed.algorithms.ddp_comm_hooks.mixed_precision_hooks import (
871                _AllreduceUpcastHookState,
872                _reducer_allreduce_and_upcast_hook,
873            )
874
875            upcast_hook_state = _AllreduceUpcastHookState(
876                ddp_weakref=weakref.ref(self),
877                upcast_stream=torch.cuda.Stream(),
878            )
879            self.register_comm_hook(
880                upcast_hook_state,
881                _reducer_allreduce_and_upcast_hook,
882            )
883            # Inform reducer of reduced precision param dtype for correctness
884            # of type checks between gradient and bucket.
885            self.reducer._set_mixed_precision_param_dtype(  # type: ignore[attr-defined]
886                self.mixed_precision.param_dtype
887            )
888
889        self._has_rebuilt_buckets = False
890
891        if static_graph:
892            self._set_static_graph()
893
894        self._lazy_init_ran = False
895
896        # Register the AccumulateGrad post hooks if optimize_ddp is
897        # True. The hooks will be deregistered if compiled_autograd is not
898        # enabled.
899        self._accum_grad_hooks: List[RemovableHandle] = []
900        optimize_ddp = torch._dynamo.config._get_optimize_ddp_mode()
901        self._use_python_reducer = optimize_ddp in (
902            "python_reducer",
903            "python_reducer_without_compiled_forward",
904        )
905        if self._use_python_reducer:
906            torch._inductor.config._fuse_ddp_communication = True
907            torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb
908            # Directly adding this to the trace rule will disturb the users
909            # who are using DDPOptimizer.
910            torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST.add(
911                "torch.nn.parallel.distributed"
912            )
913            torch._dynamo.trace_rules.get_legacy_mod_inlinelist.cache_clear()
914        self._force_to_disable_cpp_reducer = (
915            optimize_ddp == "python_reducer_without_compiled_forward"
916        )
917        if self._use_python_reducer:
918            self._register_accum_grad_hook()
919
920        # Whether or not DDPSink performs a clone.
921        self._ddp_sink_clone = True
922
923    def _register_accum_grad_hook(self):
924        import torch.distributed._functional_collectives as fcol
925
926        def compiled_accum_grad_hook(
927            param,
928            *,
929            param_index: int,
930        ):
931            if not self.require_backward_grad_sync:
932                return
933
934            if param.grad is None:
935                return
936
937            if self._comm_hooks:
938                for hook, state in self._comm_hooks:
939                    hook(state, (param.grad, param))
940            else:
941                gradient = param.grad / self.process_group.size()
942                gradient = fcol.all_reduce(gradient, "sum", self.process_group)
943                param.grad.copy_(gradient)
944
945        for index, param in enumerate(self._module_parameters):
946            if not param.requires_grad:
947                continue
948            self._accum_grad_hooks.append(
949                param.register_post_accumulate_grad_hook(
950                    functools.partial(
951                        compiled_accum_grad_hook,
952                        param_index=index,
953                    )
954                )
955            )
956
957    def _delayed_all_reduce_hook(self, grad):
958        world_size = dist.get_world_size(self.process_group)
959
960        self._delay_grad_buffer.div_(world_size)  # type: ignore[union-attr]
961        _ = dist.all_reduce(
962            self._delay_grad_buffer, group=self.process_group, async_op=True
963        )
964        return grad
965
966    def _register_delay_all_reduce_hook(
967        self,
968        bucket_cap_mb,
969        param_to_hook_all_reduce,
970        device_ids,
971    ):
972        # 1. Create gradient buffer
973        device = torch.device("cpu") if device_ids is None else device_ids[0]
974        self._delay_grad_buffer = torch.zeros(
975            sum(p.numel() for p in self._delay_all_reduce_params),
976            device=device,
977        )
978
979        # 2. Broadcast the parameters
980        detached_params = [p.detach() for p in self._delay_all_reduce_params]
981        dist._broadcast_coalesced(self.process_group, detached_params, bucket_cap_mb, 0)
982
983        # 3. Hook all reduce to the specified parameter
984        param_to_hook_all_reduce.register_hook(self._delayed_all_reduce_hook)
985
986        # 4. Build tensor views for gradients
987        offset = 0
988        for param in self._delay_all_reduce_params:
989            grad_view = self._delay_grad_buffer[offset : (offset + param.numel())].view(
990                param.shape
991            )
992            self._delay_grad_views.append(grad_view)
993            offset = offset + param.numel()
994
995        # 5. Check whether the all reduce of all params requiring grad is delayed.
996        for module_name, module in self.module.named_modules():
997            for param_name, param in module.named_parameters(recurse=False):
998                if param.requires_grad:
999                    full_name = f"{module_name}.{param_name}"
1000                    if full_name not in self.parameters_to_ignore:
1001                        # There is at least one param whose all-reduce will not be delayed.
1002                        # In this case, we should not set self._delay_all_reduce_all_params
1003                        # to True.
1004                        return
1005        self._delay_all_reduce_all_params = True
1006
1007    def _setup_in_backward_optimizers(self):
1008        # Check if user has used apply_optim_in_backward to overlap optimizer
1009        # step + DDP backward. Current constraints:
1010        # 1. Only allreduce is supported at the moment, no custom communication.
1011        # 2. For DDP-managed parameters that have their optimizer run in
1012        # backward, their gradients are set to ``None``. If your use case
1013        # requires DDP parameters grad not to be set to ``None`` after their
1014        # in-backward optimizer runs, please ping
1015        # https://github.com/pytorch/pytorch/issues/90052.
1016        # NOTE: we use self._module_parameters instead of .parameters() since
1017        # the former excludes ignored (non-DDP managed) parameters.
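        #
        # For reference, the user-facing pattern this path supports is roughly
        # (a hypothetical sketch, not executed here; names are illustrative):
        #
        #   from torch.distributed.optim import _apply_optimizer_in_backward
        #   _apply_optimizer_in_backward(
        #       optimizer_class=torch.optim.SGD,
        #       params=model.parameters(),
        #       optimizer_kwargs={"lr": 0.01},
        #   )
        #   ddp = DistributedDataParallel(model, device_ids=[rank])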
1018        if any(hasattr(p, "_in_backward_optimizers") for p in self._module_parameters):
1019            torch._C._log_api_usage_once("ddp.optimizer_in_backward")
1020            # Remove hooks that apply_optim_in_backward had registered because
1021            # DDP customizes how optimizer is overlapped with backward due to
1022            # the allreduce.
1023            param_to_handle_map = (
1024                dist.optim.apply_optimizer_in_backward.param_to_optim_hook_handle_map
1025            )
1026            for p in self._module_parameters:
1027                for handle in param_to_handle_map.get(p, []):
1028                    handle.remove()
1029
1030            # Need a weakref to DDP instance to run all_reduce (from reducer)
1031            # and get managed DDP parameters.
1032            ddp_weakref = weakref.ref(self)
1033            # Note: importing in function, otherwise this will cause a circular
1034            # import.
1035            from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
1036                _apply_optim_in_backward_hook,
1037            )
1038
1039            self.register_comm_hook(
1040                ddp_weakref,
1041                _apply_optim_in_backward_hook(
1042                    gradient_is_bucket_view=self.gradient_as_bucket_view
1043                ),
1044            )
1045
1046            self.reducer._set_optimizer_in_backward()  # type: ignore[attr-defined]
1047
1048    def _fire_reducer_autograd_hook(self, idx, *unused):
1049        """
1050        Fire the reducer's autograd hook to allreduce params in a Reducer bucket.
1051
1052    Note that this is only used during mixed precision training, as the
1053    Reducer's hooks installed at construction time would not be called
1054    since we're working with the low precision parameters.
1055        """
1056        self.reducer._autograd_hook(idx)  # type: ignore[attr-defined]
1057
1058    def _root_copy_hook(self, *args: Any, **kwargs: Any) -> None:
1059        """
1060        For DDP mixed precision, put low precision copies on separate stream and create events to wait for them.
1061
1062        When training with DDP mixed precision, this root pre-forward hook kicks
1063        off low precision copies on a separate stream and creates respective
1064        events to wait for them.
1065        """
1066        # Clear out the previous iteration's submodule-to-event mapping, since we
1067        # may have populated events for modules that didn't end up being
1068        # used.
1069        self._submodule_to_event = defaultdict(deque)  # type: ignore[var-annotated]
1070        with torch.cuda.stream(self._mp_stream):
1071            for submodule in self.module.modules():
1072                for param in submodule.parameters(recurse=False):
1073                    # Do not cast DDP ignored parameters.
1074                    if hasattr(param, "_ddp_ignored") and param._ddp_ignored:
1075                        continue
1076                    _alloc_storage(param._mp_param, param.size())
1077                    # copy() implicitly casts to low precision
1078                    with torch.no_grad():
1079                        param._mp_param.copy_(param.data)
1080                        # TODO: when zero_grad(set_to_none=False) or in grad
1081                        # accumulation case, accumulated grads can be in fp32
1082                        # which can cause errors when running DDP backwards due
1083                        # to mismatched incoming and accumulated gradient types.
1084                        # So we manually cast the accumulated grad down for now,
1085                        # in the future we may shift to FSDP style gradient
1086                        # accumulation management where the accumulated gradient
1087                        # is saved and .grad field is set to None, bypassing
1088                        # this issue.
1089                        if param.grad is not None:
1090                            param.grad.data = param.grad.to(
1091                                self.mixed_precision.param_dtype  # type: ignore[union-attr]
1092                            )
1093                    param.data = param._mp_param
1094                copy_event = torch.cuda.Event()
1095                copy_event.record()
1096                self._submodule_to_event[submodule].append(copy_event)
1097
1098    def _module_wait_for_copy_hook(
1099        self,
1100        module,
1101        *args: Any,
1102        **kwargs: Any,
1103    ) -> None:
1104        """Before carrying out computation, wait on the appropriate event to ensure low precision copies have finished."""
1105        try:
1106            event = self._submodule_to_event[module].popleft()
1107        except IndexError:
1108            # copy event has already been waited on
1109            return
1110
1111        event.wait(stream=torch.cuda.current_stream())
1112        for p in module.parameters(recurse=False):
1113            # Don't register hooks if param does not require grad
1114            if not p.requires_grad or (hasattr(p, "_ddp_ignored") and p._ddp_ignored):
1115                continue
1116            # We need to register the autograd hook here instead of in DDP's ctor
1117            # since we're working with the low precision param. Register it
1118            # via obtaining the gradient accumulator.
1119            tmp = p.expand_as(p)
1120            grad_acc = tmp.grad_fn.next_functions[0][0]
1121
1122            hook = grad_acc.register_hook(
1123                functools.partial(self._fire_reducer_autograd_hook, p._idx)
1124            )
1125            p._ddp_mp_hook_state = (grad_acc, hook)
1126
1127    def _log_and_throw(self, err_type, err_msg):
1128        if self.logger is not None:
1129            self.logger.set_error_and_log(f"{str(err_type)}: {err_msg}")
1130        raise err_type(err_msg)
1131
1132    def _ddp_init_helper(
1133        self,
1134        parameters,
1135        expect_sparse_gradient,
1136        param_to_name_mapping,
1137        static_graph,
1138    ):
1139        """
1140        DDP init helper function to manage parameters, grad hooks, logging, and SyncBatchNorm.
1141
1142        Initialization helper function that does the following:
1143        (1) bucketing the parameters for reductions
1144        (2) resetting the bucketing states
1145        (3) registering the grad hooks
1146        (4) logging construction-time DDP logging data
1147        (5) passing a handle of DDP to the SyncBatchNorm layer
1148        """
1149        # Note that the parameter order is not necessarily the order in which
1150        # the parameters are used, especially in models with control flow.
1151        #
1152        # Moreover, since parameters are not presented in their real execution
1153        # order, if a model also happens to
1154        #   1) have other collective comm ops in its backward graph, and
1155        #   2) have unused parameters on a subset of ranks,
1156        # bucketing could insert an ALL-REDUCE comm op too early on the ranks
1157        # with unused parameters, unexpectedly matching it up with other
1158        # collective comm ops on the other ranks.
1159        #
1160        # To handle this corner case, when the parameters are not in the real
1161        # execution order, we don't do bucketing, so only one ALL-REDUCE is
1162        # inserted after all the gradients of the whole graph are computed.
1163        #
1164        # Note that we only disable bucketing for the first iteration. After that,
1165        # it's OK to rebuild buckets, because "bucket rebuild" bucketizes parameters based on their real execution order in the backward graph.
1166
1167        # Can remove this branching once #73732 is landed.
1168        if static_graph is True or self.find_unused_parameters is False:
1169            bucket_size_limits = [sys.maxsize]
1170        else:
1171            if self.bucket_bytes_cap_default:
1172                bucket_size_limits = [
1173                    dist._DEFAULT_FIRST_BUCKET_BYTES,
1174                    self.bucket_bytes_cap,
1175                ]
1176            else:
1177                bucket_size_limits = [self.bucket_bytes_cap]
1178        (
1179            bucket_indices,
1180            per_bucket_size_limits,
1181        ) = dist._compute_bucket_assignment_by_size(
1182            parameters,
1183            bucket_size_limits,
1184            expect_sparse_gradient,
1185        )
1186
1187        # Remember index for parameters if we are in mixed precision, as we
1188        # need to pass in index to Reducer's autograd hook via python.
1189        if self.mixed_precision is not None:
1190            for i, p in enumerate(parameters):
1191                p._idx = i
1192
1193        # Note: reverse list of buckets because we want to approximate the
1194        # order in which their gradients are produced, and assume they
1195        # are used in the forward pass in the order they are defined.
1196        self.reducer = dist.Reducer(
1197            parameters,
1198            list(reversed(bucket_indices)),
1199            list(reversed(per_bucket_size_limits)),
1200            self.process_group,
1201            expect_sparse_gradient,
1202            # The bucket size limit is specified in the constructor.
1203            # Additionally, we allow for a single small bucket for parameters
1204            # that are defined first, such that their gradients don't spill into
1205            # a much larger bucket, adding unnecessary latency after gradient
1206            # computation finishes. Experiments showed 1MB is a reasonable value.
1207            self.bucket_bytes_cap,
1208            self.find_unused_parameters,
1209            self.gradient_as_bucket_view,
1210            param_to_name_mapping,
1211            # User can set dist._DEFAULT_FIRST_BUCKET_BYTES to tune DDP first
1212            # bucket.
1213            (
1214                dist._DEFAULT_FIRST_BUCKET_BYTES
1215                if self.bucket_bytes_cap_default
1216                else self.bucket_bytes_cap
1217            ),
1218        )
1219
1220        self.logger = dist.Logger(self.reducer)
1221        # Set as a weak reference to avoid reference cycle between
1222        # logger and reducer.
1223        self.reducer.set_logger(self.logger)
1224
1225        has_sync_bn = False
1226        for submodule in self.module.modules():
1227            if isinstance(submodule, torch.nn.SyncBatchNorm):
1228                has_sync_bn = True
1229                break
1230
        # Set logging data that can be gathered at construction time.
1232        self.logger.set_construction_data_and_log(
1233            self.module.__class__.__name__,
1234            [] if self.device_ids is None else self.device_ids,
1235            -1 if self.output_device is None else self.output_device,
1236            self.broadcast_buffers,
1237            has_sync_bn,
1238            static_graph,
1239        )
1240
1241        # passing a handle to torch.nn.SyncBatchNorm layer
1242        self._passing_sync_batchnorm_handle(self.module)
1243
1244    def __getstate__(self):
1245        self._check_default_group()
1246        attrs = copy.copy(self.__dict__)
1247        del attrs["process_group"]
1248        del attrs["reducer"]
1249        del attrs["logger"]
1250        return attrs
1251
1252    def __setstate__(self, state):
1253        # If serializable, then the process group should be the default one
1254        self.process_group = _get_default_group()
1255        super().__setstate__(state)
1256        self.__dict__.setdefault("require_forward_param_sync", True)
1257        self.__dict__.setdefault("require_backward_grad_sync", True)
1258        parameters, expect_sparse_gradient = self._build_params_for_reducer()
1259        # In debug mode, build a mapping of parameter index -> parameter.
1260        param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters)
1261        # Builds reducer.
1262        self._ddp_init_helper(
1263            parameters,
1264            expect_sparse_gradient,
1265            param_to_name_mapping,
1266            self.static_graph,
1267        )
1268        if self.static_graph:
1269            self.reducer._set_static_graph()
1270            assert self.logger is not None
1271            self.logger._set_static_graph()
1272
1273    def _build_params_for_reducer(self):
1274        # Build tuple of (module, parameter) for all parameters that require grads.
1275        modules_and_parameters = [
1276            (module, parameter)
1277            for module_name, module in self.module.named_modules()
1278            for parameter in [
1279                param
1280                # Note that we access module.named_parameters instead of
1281                # parameters(module). parameters(module) is only needed in the
                # single-process multi-device case, where it accesses replicated
1283                # parameters through _former_parameters.
1284                for param_name, param in module.named_parameters(recurse=False)
1285                if param.requires_grad
1286                and f"{module_name}.{param_name}" not in self.parameters_to_ignore
1287            ]
1288        ]
1289
1290        # Deduplicate any parameters that might be shared across child modules.
1291        memo = set()
1292        modules_and_parameters = [
1293            # "p not in memo" is the deduplication check.
1294            # "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed.
1295            (m, p)
1296            for m, p in modules_and_parameters
1297            if p not in memo and not memo.add(p)  # type: ignore[func-returns-value]
1298        ]
1299
1300        # Build list of parameters.
1301        parameters = [parameter for _, parameter in modules_and_parameters]
1302
1303        # Checks if a module will produce a sparse gradient.
1304        def produces_sparse_gradient(module):
1305            if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)):
1306                return module.sparse
1307            return False
1308
1309        # Build list of booleans indicating whether or not to expect sparse
1310        # gradients for the corresponding parameters.
1311        expect_sparse_gradient = [
1312            produces_sparse_gradient(module) for module, _ in modules_and_parameters
1313        ]
1314
1315        self._assign_modules_buffers()
1316
1317        return parameters, expect_sparse_gradient
1318
1319    def _assign_modules_buffers(self):
1320        """
1321        Assign self.module.named_buffers to self.modules_buffers.
1322
        Assigns module buffers to self.modules_buffers, which are then broadcast
        across ranks when broadcast_buffers=True. Note that this
        must be called every time buffers need to be synced because buffers can
        be reassigned by the user's module;
        see https://github.com/pytorch/pytorch/issues/63916.
1328        """
1329        # Collect buffers for modules, filtering out buffers that should be ignored.
1330        named_module_buffers = [
1331            (buffer, buffer_name)
1332            for buffer_name, buffer in self.module.named_buffers()
1333            if buffer_name not in self.parameters_to_ignore
1334        ]
1335        self.modules_buffers = [
1336            buffer for (buffer, buffer_name) in named_module_buffers
1337        ]
1338        # Dict[str, tensor] representing module buffers not ignored by DDP.
1339        self.named_module_buffers = {
1340            buffer_name: buffer for (buffer, buffer_name) in named_module_buffers
1341        }
1342
1343    def _build_debug_param_to_name_mapping(self, parameters):
1344        param_to_param_index = {parameters[i]: i for i in range(len(parameters))}
1345        param_set = set(parameters)
1346        param_index_to_param_fqn = {}
1347        for module_name, module in self.module.named_modules():
1348            for param_name, param in module.named_parameters(recurse=False):
1349                fqn = f"{module_name}.{param_name}"
1350                # Bypass ignored parameters since those are not reduced by DDP
1351                # to begin with.
1352                if fqn not in self.parameters_to_ignore and param.requires_grad:
1353                    if param not in param_set:
1354                        self._log_and_throw(
1355                            ValueError,
1356                            f"Param with name {fqn} found in module parameters, but not DDP parameters."
1357                            " This indicates a bug in DDP, please report an issue to PyTorch.",
1358                        )
1359                    param_index = param_to_param_index[param]
1360                    param_index_to_param_fqn[param_index] = fqn
1361
1362        # Ensure we covered all parameters
1363        if len(param_set) != len(param_index_to_param_fqn):
1364            self._log_and_throw(
1365                ValueError,
1366                (
1367                    "Expected param to name mapping to cover all parameters, but"
1368                    f" got conflicting lengths: {len(param_set)} vs "
1369                    f"{len(param_index_to_param_fqn)}. This indicates a bug in DDP"
1370                    ", please report an issue to PyTorch."
1371                ),
1372            )
1373
1374        return param_index_to_param_fqn
1375
1376    def _get_parameters(self, m, recurse=True):
1377        """Return a generator of module parameters."""
1378
1379        def model_parameters(m):
1380            ps = (
1381                m._former_parameters.values()
1382                if hasattr(m, "_former_parameters")
1383                else m.parameters(recurse=False)
1384            )
1385            yield from ps
1386
1387        for mod in m.modules() if recurse else [m]:
1388            yield from model_parameters(mod)
1389
1390    def _check_default_group(self):
1391        pickle_not_supported = False
1392        try:
1393            if self.process_group != _get_default_group():
1394                pickle_not_supported = True
1395        except RuntimeError:
1396            pickle_not_supported = True
1397
1398        if pickle_not_supported:
1399            self._log_and_throw(
1400                RuntimeError,
1401                "DDP Pickling/Unpickling are only supported "
1402                "when using DDP with the default process "
1403                "group. That is, when you have called "
1404                "init_process_group and have not passed "
1405                "process_group argument to DDP constructor",
1406            )
1407
1408    @contextmanager
1409    def no_sync(self):
1410        r"""
1411        Context manager to disable gradient synchronizations across DDP processes.
1412
1413        Within this context, gradients will be accumulated on module
1414        variables, which will later be synchronized in the first
1415        forward-backward pass exiting the context.
1416
1417        Example::
1418
1419            >>> # xdoctest: +SKIP("undefined variables")
1420            >>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
1421            >>> with ddp.no_sync():
1422            >>>     for input in inputs:
1423            >>>         ddp(input).backward()  # no synchronization, accumulate grads
1424            >>> ddp(another_input).backward()  # synchronize grads
1425
1426        .. warning::
1427            The forward pass should be included inside the context manager, or
1428            else gradients will still be synchronized.
1429        """
1430        old_require_backward_grad_sync = self.require_backward_grad_sync
1431        self.require_backward_grad_sync = False
1432        try:
1433            yield
1434        finally:
1435            self.require_backward_grad_sync = old_require_backward_grad_sync
1436
1437    @classmethod
1438    def _get_active_ddp_module(cls):
1439        """`TorchDynamo` requires DDP's status and module for cooperative optimization."""
1440        return cls._active_ddp_module
1441
    # Note: this ctxmgr function is marked 'skip' in torchdynamo, so dynamo only kicks in
    # for the 'module_to_run' underneath it.
    # See torch/_dynamo/eval_frame.py TorchPatcher.patch for more details.
1445    @contextmanager
1446    @torch._disable_dynamo(recursive=False)
1447    def _inside_ddp_forward(self):
1448        DistributedDataParallel._active_ddp_module = self
1449        try:
1450            yield
1451        finally:
1452            DistributedDataParallel._active_ddp_module = None
1453
1454    def _run_ddp_forward(self, *inputs, **kwargs):
1455        if self._use_python_reducer:
1456            return self.module(*inputs, **kwargs)  # type: ignore[index]
1457        else:
1458            with self._inside_ddp_forward():
1459                return self.module(*inputs, **kwargs)  # type: ignore[index]
1460
1461    def _clear_grad_buffer(self):
        # Making param.grad point to the grad buffers before backward relies on
        # the assumption that gradient accumulation is done in place by the
        # autograd engine. In some edge cases, if the accumulation is not in
        # place, param.grad and the grad buffers become detached from each other.
1466        if self._delay_grad_buffer is not None:
1467            # We batch zero_grad for all params by resetting the whole grad
1468            # buffer when the grad of all params is set to None.
1469            all_param_grad_none = all(
1470                param.grad is None for param in self._delay_all_reduce_params
1471            )
1472
1473            for index, param in enumerate(self._delay_all_reduce_params):
1474                if param.grad is None:
1475                    param.grad = self._delay_grad_views[index]
1476                    if not all_param_grad_none:
1477                        param.grad.zero_()
1478
1479            if all_param_grad_none:
1480                self._delay_grad_buffer.zero_()
1481
1482    def _lazy_init(self):
1483        # Initialization for DDP that occurs after construction, but lazily
1484        # before the first forward pass.
1485        self._setup_in_backward_optimizers()
1486        self._lazy_init_ran = True
1487
1488    def _should_disable_cpp_reducer(self) -> bool:
1489        return self._use_python_reducer and (
1490            torch._utils.is_compiling() or self._force_to_disable_cpp_reducer
1491        )
1492
1493    def _pre_forward(self, *inputs, **kwargs):
1494        if self._should_disable_cpp_reducer():
1495            return inputs, kwargs
1496
        # Disable the Python reducer if compiled_autograd is not enabled.
1498        if self._accum_grad_hooks:
1499            for index, h in enumerate(self._accum_grad_hooks):
1500                h.remove()
1501            self._accum_grad_hooks.clear()
1502
1503        if not self._lazy_init_ran and not torch._utils.is_compiling():
1504            self._lazy_init()
1505
1506        if self._delay_all_reduce_all_params:
1507            return inputs, kwargs
1508
1509        if torch.is_grad_enabled() and self.require_backward_grad_sync:
1510            assert self.logger is not None
1511            self.logger.set_runtime_stats_and_log()
1512            self.reducer.prepare_for_forward()
1513
1514        # Notify the join context that this process has not joined, if
1515        # needed
1516        work = Join.notify_join_context(self)
1517        if work:
1518            self.reducer._set_forward_pass_work_handle(
1519                work, self._divide_by_initial_world_size  # type: ignore[arg-type]
1520            )
1521
        # Call _rebuild_buckets before forward computation.
        # _rebuild_buckets may allocate new buckets before deallocating the old
        # ones, so to reduce peak memory usage we call it before the forward
        # pass, where memory usage has not yet peaked.
        # Buckets should be rebuilt only once during the whole training period.
1528        if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
1529            logger.info("Reducer buckets have been rebuilt in this iteration.")
1530            self._has_rebuilt_buckets = True
1531
        # Sync module buffers before the forward pass if the user-specified
        # buffer hook location requires it (pre-forward is the default).
1534        if self._check_sync_bufs_pre_fwd():
1535            self._sync_buffers()
1536
1537        if self._join_config.enable:
1538            # Notify joined ranks whether they should sync in backwards pass or not.
1539            self._check_global_requires_backward_grad_sync(is_joined_rank=False)
1540
1541        if self.device_ids:
1542            moved_inputs, moved_kwargs = _to_kwargs(
1543                inputs,
1544                kwargs,
1545                torch.device(self.device_type, self.device_ids[0]),
1546                self.use_side_stream_for_tensor_copies,
1547            )
1548            args, kwargs = moved_inputs[0], moved_kwargs[0]
1549            # Cast inputs to reduced precision if needed.
1550            if self.mixed_precision is not None:
1551                args, kwargs = _cast_forward_inputs(
1552                    self.mixed_precision.param_dtype,
1553                    *args,
1554                    **kwargs,
1555                )
1556            return args, kwargs
1557        else:
1558            # Cast inputs to reduced precision if needed.
1559            # TODO (rohan-varma) test this codepath.
1560            if self.mixed_precision is not None:
1561                inputs, kwargs = _cast_forward_inputs(
1562                    self.mixed_precision.param_dtype,
1563                    *inputs,
1564                    **kwargs,
1565                )
1566            return inputs, kwargs
1567
1568    def _post_forward(self, output):
1569        if self._should_disable_cpp_reducer():
1570            return output
1571
1572        if self._delay_all_reduce_all_params:
1573            self._clear_grad_buffer()
1574            return output
1575
        # Sync module buffers after the forward pass if the user registered a
        # buffer hook with the post-forward location.
1578        if self._check_sync_bufs_post_fwd():
1579            self._sync_buffers()
1580
1581        if torch.is_grad_enabled() and self.require_backward_grad_sync:
1582            self.require_forward_param_sync = True
            # We'll return the output object verbatim since it is a freeform
            # object. We need to find any tensors in this object, though,
            # because we need to figure out which parameters were used during
            # this forward pass, to ensure we short-circuit reduction for any
            # unused parameters. This only applies if `find_unused_parameters`
            # is set.
1588            if self.find_unused_parameters and not self.static_graph:
1589                # Do not need to populate this for static graph.
1590                self.reducer.prepare_for_backward(list(_find_tensors(output)))
1591            else:
1592                self.reducer.prepare_for_backward([])
1593        else:
1594            self.require_forward_param_sync = False
1595
1596        # TODO: DDPSink is currently enabled for unused parameter detection and
1597        # static graph training for first iteration.
1598        if (self.find_unused_parameters and not self.static_graph) or (
1599            self.static_graph and not self._static_graph_delay_allreduce_enqueued
1600        ):
1601            (
1602                output_tensor_list,
1603                treespec,
1604                output_is_rref,
1605            ) = _tree_flatten_with_rref(output)
1606            output_placeholders: List[Optional[torch.Tensor]] = [
1607                None for _ in range(len(output_tensor_list))
1608            ]
1609            # Do not touch tensors that have no grad_fn, which can cause issues
1610            # such as https://github.com/pytorch/pytorch/issues/60733
1611            for i, output in enumerate(output_tensor_list):
1612                if torch.is_tensor(output) and output.grad_fn is None:
1613                    output_placeholders[i] = output
1614
            # When find_unused_parameters=True, make tensors that require grad
            # run through the DDPSink backward pass. When not all outputs are
            # used in the loss, the corresponding tensors receive an undefined
            # gradient, which the reducer handles so that the param.grad field
            # is not touched and we don't error out.
1620            passthrough_tensor_list = _DDPSink.apply(
1621                weakref.ref(self),
1622                *output_tensor_list,
1623            )
1624            for i in range(len(output_placeholders)):
1625                if output_placeholders[i] is None:
1626                    output_placeholders[i] = passthrough_tensor_list[i]
1627
1628            # Reconstruct output data structure.
1629            output = _tree_unflatten_with_rref(
1630                output_placeholders, treespec, output_is_rref
1631            )
1632
1633        # At the end of the forward pass, reset the grad buffer and grad views
1634        self._clear_grad_buffer()
1635        return output
1636
1637    def forward(self, *inputs, **kwargs):
1638        with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
1639            inputs, kwargs = self._pre_forward(*inputs, **kwargs)
1640            output = (
1641                self.module.forward(*inputs, **kwargs)
1642                if self._delay_all_reduce_all_params
1643                else self._run_ddp_forward(*inputs, **kwargs)
1644            )
1645            return self._post_forward(output)
1646
1647    def scatter(self, inputs, kwargs, device_ids):
1648        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
1649
1650    def to_kwargs(self, inputs, kwargs, device_id):
1651        # Kept for BC
1652        return _to_kwargs(
1653            inputs,
1654            kwargs,
1655            torch.device(self.device_type, device_id),
1656            self.use_side_stream_for_tensor_copies,
1657        )
1658
1659    def gather(self, outputs, output_device):
1660        return gather(outputs, output_device, dim=self.dim)
1661
1662    def train(self, mode=True):
1663        super().train(mode)
1664        return self
1665
1666    # When running in join mode, schedules an allreduce to notify joined ranks
1667    # of whether backwards pass synchronization will run this iteration or not.
1668    def _check_global_requires_backward_grad_sync(self, is_joined_rank):
1669        if not is_joined_rank and self.require_backward_grad_sync:
1670            requires_sync_tensor = torch.ones(1, device=self.device)
1671        else:
1672            requires_sync_tensor = torch.zeros(1, device=self.device)
1673
1674        work = dist.all_reduce(
1675            requires_sync_tensor, group=self.process_group, async_op=True
1676        )
1677
1678        # (kwen2501) This if condition is a plain translation of previous
1679        # behavior, i.e. in the `is_joined_rank=False` case, `work.wait()`
1680        # is not called and it doesn't care about the result. I am guessing
1681        # that it just wants to fire a matching all-reduce and does not want
1682        # the main stream to wait.
1683        if is_joined_rank:
1684            work.wait()
1685            should_sync_backwards = requires_sync_tensor.item() != 0
1686            return should_sync_backwards
1687        else:
1688            return None  # Return value is not/should not be used.
1689
1690    # When running in join mode, checks and performs sync of module buffers if
1691    # the models have buffers that should be synchronized in the forward pass.
1692    def _check_and_sync_module_buffers(self):
1693        if self._check_sync_bufs_pre_fwd():
1694            authoritative_rank = self._find_common_rank(self._distributed_rank, False)
1695            self._sync_module_buffers(authoritative_rank)
1696
    # When running in join mode, agrees upon a common rank and broadcasts model
    # parameters to all other ranks.
1699    def _sync_final_model(self, is_last_joiner):
1700        # Agree upon the process that will be the authoritative model copy.
1701        # The current rank is a candidate for being the authoritative copy if
1702        # is_last_joiner=True. We break ties via picking the larger rank.
1703        self._authoritative_rank = self._find_common_rank(
1704            self._distributed_rank, is_last_joiner
1705        )
1706        _sync_module_states(
1707            module=self.module,
1708            process_group=self.process_group,
1709            broadcast_bucket_size=self.broadcast_bucket_size,
1710            src=self._authoritative_rank,
1711            params_and_buffers_to_ignore=self.parameters_to_ignore,
1712            broadcast_buffers=self.broadcast_buffers,
1713        )
1714
1715    # Schedule comm ops to match those scheduled in the reducer's backward
1716    # pass.
1717    def _match_all_reduce_for_bwd_pass(self):
1718        comm_work = []
1719        # Schedule comm in the same order as Reducer schedules them, i.e.
1720        # the order of the buckets. Retrieving the bucket order from the reducer
1721        # ensures that we keep the same order in join mode, such as when bucket
1722        # order is rebuilt dynamically.
1723
1724        # Returns grad_buckets in order, but real tensors are substituted with
1725        # zero tensors of the same shape.
1726        grad_buckets = self.reducer._get_zeros_like_grad_buckets()
1727        for grad_bucket in grad_buckets:
            # Joined processes contribute zero gradient. In the case that
            # divide_by_initial_world_size=True, we divide grads by the static
            # world size; if not, the dividing factor is reduced by the number
            # of joined processes.
1732            work = self.reducer._run_comm_hook(grad_bucket)
1733            comm_work.append(work)
1734        for work in comm_work:
1735            work.wait()
1736
1737    # Allreduces the used parameter mapping across ranks.
1738    def _match_unused_params_allreduce(self):
1739        locally_used_param_map = self.reducer._get_local_used_map()
1740        self.process_group.allreduce(locally_used_param_map)
1741
1742    def join(
1743        self,
1744        divide_by_initial_world_size: bool = True,
1745        enable: bool = True,
1746        throw_on_early_termination: bool = False,
1747    ):
1748        r"""
1749        Context manager for training with uneven inputs across processes in DDP.
1750
1751        This context manager will keep track of already-joined DDP processes,
1752        and "shadow" the forward and backward passes by inserting collective
1753        communication operations to match with the ones created by non-joined
1754        DDP processes. This will ensure each collective call has a corresponding
1755        call by already-joined DDP processes, preventing hangs or errors that
1756        would otherwise happen when training with uneven inputs across
1757        processes. Alternatively, if the flag ``throw_on_early_termination`` is
1758        specified to be ``True``, all trainers will throw an error once one rank
1759        runs out of inputs, allowing these errors to be caught and handled
1760        according to application logic.
1761
1762        Once all DDP processes have joined, the context manager will broadcast
1763        the model corresponding to the last joined process to all processes to
1764        ensure the model is the same across all processes
1765        (which is guaranteed by DDP).
1766
1767        To use this to enable training with uneven inputs across processes,
1768        simply wrap this context manager around your training loop. No further
        modifications to the model or data loading are required.
1770
1771        .. warning::
1772            If the model or training loop this context manager is wrapped around
1773            has additional distributed collective operations, such as
1774            ``SyncBatchNorm`` in the model's forward pass, then the flag
1775            ``throw_on_early_termination`` must be enabled. This is because this
1776            context manager is not aware of non-DDP collective communication.
1777            This flag will cause all ranks to throw when any one rank
1778            exhausts inputs, allowing these errors to be caught and recovered
1779            from across all ranks.
1780
1781        Args:
1782            divide_by_initial_world_size (bool): If ``True``, will divide
1783                gradients by the initial ``world_size`` DDP training was launched
1784                with. If ``False``, will compute the effective world size
1785                (number of ranks that have not depleted their inputs yet) and
1786                divide gradients by that during allreduce. Set
1787                ``divide_by_initial_world_size=True`` to ensure every input
                sample, including the uneven inputs, has equal weight in terms of
1789                how much they contribute to the global gradient. This is
1790                achieved by always dividing the gradient by the initial
1791                ``world_size`` even when we encounter uneven inputs. If you set
1792                this to ``False``, we divide the gradient by the remaining
1793                number of nodes. This ensures parity with training on a smaller
1794                ``world_size`` although it also means the uneven inputs would
1795                contribute more towards the global gradient. Typically, you
1796                would want to set this to ``True`` for cases where the last few
1797                inputs of your training job are uneven. In extreme cases, where
1798                there is a large discrepancy in the number of inputs, setting
1799                this to ``False`` might provide better results.
1800            enable (bool): Whether to enable uneven input detection or not. Pass
1801                in ``enable=False`` to disable in cases where you know that
1802                inputs are even across participating processes. Default is
1803                ``True``.
1804            throw_on_early_termination (bool): Whether to throw an error
1805                or continue training when at least one rank has exhausted
1806                inputs. If ``True``, will throw upon the first rank reaching end
1807                of data. If ``False``, will continue training with a smaller
1808                effective world size until all ranks are joined. Note that if
1809                this flag is specified, then the flag
1810                ``divide_by_initial_world_size`` would be ignored. Default
1811                is ``False``.
1812
1813
1814        Example::
1815
1816            >>> # xdoctest: +SKIP("Distributed")
1817            >>> import torch
1818            >>> import torch.distributed as dist
1819            >>> import os
1820            >>> import torch.multiprocessing as mp
1821            >>> import torch.nn as nn
1822            >>> # On each spawned worker
1823            >>> def worker(rank):
1824            >>>     dist.init_process_group("nccl", rank=rank, world_size=2)
1825            >>>     torch.cuda.set_device(rank)
1826            >>>     model = nn.Linear(1, 1, bias=False).to(rank)
1827            >>>     model = torch.nn.parallel.DistributedDataParallel(
1828            >>>         model, device_ids=[rank], output_device=rank
1829            >>>     )
1830            >>>     # Rank 1 gets one more input than rank 0.
1831            >>>     inputs = [torch.tensor([1]).float() for _ in range(10 + rank)]
1832            >>>     with model.join():
1833            >>>         for _ in range(5):
1834            >>>             for inp in inputs:
1835            >>>                 loss = model(inp).sum()
1836            >>>                 loss.backward()
1837            >>>     # Without the join() API, the below synchronization will hang
1838            >>>     # blocking for rank 1's allreduce to complete.
1839            >>>     torch.cuda.synchronize(device=rank)
1840        """
1841        return Join(
1842            [self],
1843            enable,
1844            throw_on_early_termination,
1845            divide_by_initial_world_size=divide_by_initial_world_size,
1846        )
1847
1848    def join_hook(
1849        self,
1850        **kwargs,
1851    ):
1852        r"""
1853        DDP join hook enables training on uneven inputs by mirroring communications in forward and backward passes.
1854
1855        Arguments:
1856            kwargs (dict): a :class:`dict` containing any keyword arguments
1857                to modify the behavior of the join hook at run time; all
1858                :class:`Joinable` instances sharing the same join context
1859                manager are forwarded the same value for ``kwargs``.
1860
1861        The hook supports the following keyword arguments:
1862            divide_by_initial_world_size (bool, optional):
1863                If ``True``, then gradients are divided by the initial world
1864                size that DDP was launched with.
1865                If ``False``, then gradients are divided by the effective world
1866                size (i.e. the number of non-joined processes), meaning that
1867                the uneven inputs contribute more toward the global gradient.
1868                Typically, this should be set to ``True`` if the degree of
1869                unevenness is small but can be set to ``False`` in extreme
1870                cases for possibly better results.
1871                Default is ``True``.
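
        Example (illustrative sketch; assumes an already-constructed ``ddp``
        instance, an ``optimizer``, and an ``inputs`` iterable, and uses the
        generic ``Join`` context manager, which forwards its keyword arguments
        to this hook)::

            >>> # xdoctest: +SKIP("Distributed")
            >>> from torch.distributed.algorithms.join import Join
            >>> with Join([ddp], divide_by_initial_world_size=False):
            >>>     for inp in inputs:
            >>>         ddp(inp).sum().backward()
            >>>         optimizer.step()
            >>>         optimizer.zero_grad()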
1872        """
1873        divide_by_initial_world_size = kwargs.get("divide_by_initial_world_size", True)
1874        return _DDPJoinHook(
1875            self, divide_by_initial_world_size=divide_by_initial_world_size
1876        )
1877
1878    @property
1879    def join_device(self):
1880        return self.device
1881
1882    @property
1883    def join_process_group(self):
1884        return self.process_group
1885
1886    def _register_buffer_comm_hook(
1887        self,
1888        state,
1889        hook: Callable,
1890        comm_hook_location=_BufferCommHookLocation.POST_FORWARD,
1891    ):
1892        r"""
        Allow custom registration of hooks that define how buffers are synchronized across ranks.
1894
        The hook takes an optional state and is passed a Dict[str, Tensor]
        mapping buffer names to buffers, and can run arbitrary reductions
        on buffers as opposed to DDP's default broadcast from rank 0. This is useful,
        for example, if a counter needs to be summed or averaged across ranks every iteration.
1899
1900        Args:
1901            state (Any): Optional state that is passed to the hook.
            hook (Callable): Callable with the following signature:
                         ``hook(state: object, buffers: Dict[str, torch.Tensor]) -> Optional[List[torch.futures.Future[torch.Tensor]]]``
1904            comm_hook_location (_BufferCommHookLocation): Enum value indicating
1905                            where to run the hook.
1906                            _BufferCommHookLocation.PRE_FORWARD means that the
1907                            hook will run _before_ the forward pass, and
1908                            _BufferCommHookLocation.POST_FORWARD means that the
1909                            hook will run _after_ the forward pass.
1910
1911            NOTE: To maximize performance, users can return a
1912                List[torch.futures.Future] from their hook, and DDP will
                install and await these futures appropriately at the end of
1914                the backward pass. This will ensure all buffers are
1915                synchronized by the end of the backward pass. If this
1916                setting is used, it is recommended to pass
1917                comm_hook_location=_BufferCommHookLocation.POST_FORWARD,
1918                which will trigger the hook after the forward pass.
1919                If _BufferCommHookLocation.PRE_FORWARD is used, users must
1920                ensure appropriate synchronization when manipulating GPU
1921                buffers in the forward pass.
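
        Example (illustrative sketch; ``ddp`` and ``process_group`` are assumed
        to already exist, and the hook below allreduce-averages floating point
        buffers instead of broadcasting them from rank 0)::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> from typing import Dict, List
            >>> def buffer_avg_hook(
            >>>     state, named_buffers: Dict[str, torch.Tensor]
            >>> ) -> List[torch.futures.Future[torch.Tensor]]:
            >>>     world_size = torch.distributed.get_world_size(group=state)
            >>>     futs = []
            >>>     for name, buf in named_buffers.items():
            >>>         # Only average floating point buffers in this sketch.
            >>>         if buf.is_floating_point():
            >>>             buf.div_(world_size)
            >>>             work = torch.distributed.all_reduce(
            >>>                 buf, group=state, async_op=True
            >>>             )
            >>>             futs.append(work.get_future())
            >>>     return futs
            >>> ddp._register_buffer_comm_hook(state=process_group, hook=buffer_avg_hook)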
1922        """
1923        assert callable(hook)
1924        self.buffer_hook = _BufferCommHook(
1925            buffer_comm_hook=hook,
1926            buffer_comm_hook_state=state,
1927            buffer_comm_hook_location=comm_hook_location,
1928        )
1929
1930    def register_comm_hook(self, state: object, hook: Callable):
1931        r"""
1932        Register communication hook for user-defined DDP aggregation of gradients across multiple workers.
1933
        This hook is useful for researchers trying out new ideas. For
        example, it can be used to implement algorithms like GossipGrad
        and gradient compression, which involve different communication strategies for
        parameter syncs while running DistributedDataParallel training.
1938
1939        Args:
1940            state (object): Passed to the hook to maintain any state information during the training process.
1941                            Examples include error feedback in gradient compression,
1942                            peers to communicate with next in GossipGrad, etc.
1943
1944                            It is locally stored by each worker
1945                            and shared by all the gradient tensors on the worker.
1946            hook (Callable): Callable with the following signature:
1947                             ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``:
1948
1949                             This function is called once the bucket is ready. The
1950                             hook can perform whatever processing is needed and return
1951                             a Future indicating completion of any async work (ex: allreduce).
1952                             If the hook doesn't perform any communication, it still
                             must return a completed Future. The Future should hold the
                             new value of the grad bucket's tensors. Once a bucket is ready,
                             the c10d reducer calls this hook, uses the tensor returned by
                             the Future, and copies the grads to the individual parameters.
                             Note that the future's return type must be a single tensor.
1958
1959                             We also provide an API called ``get_future`` to retrieve a
1960                             Future associated with the completion of ``c10d.ProcessGroup.Work``.
1961                             ``get_future`` is currently supported for NCCL and also supported for most
1962                             operations on GLOO and MPI, except for peer to peer operations (send/recv).
1963
1964        .. warning ::
            Grad bucket's tensors will not be predivided by world_size. The user is
            responsible for dividing by the world_size when using operations like allreduce.
1967
1968        .. warning ::
1969            DDP communication hook can only be registered once and should be registered
1970            before calling backward.
1971
1972        .. warning ::
1973            The Future object that hook returns should contain a single tensor
1974            that has the same shape with the tensors inside grad bucket.
1975
1976        .. warning ::
1977            ``get_future`` API supports NCCL, and partially GLOO and MPI backends (no support
1978            for peer-to-peer operations like send/recv) and will return a ``torch.futures.Future``.
1979
1980        Example::
1981            Below is an example of a noop hook that returns the same tensor.
1982
1983            >>> # xdoctest: +SKIP('undefined name')
1984            >>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
1985            >>>     fut = torch.futures.Future()
1986            >>>     fut.set_result(bucket.buffer())
1987            >>>     return fut
1988            >>> ddp.register_comm_hook(state=None, hook=noop)
1989
1990        Example::
1991            Below is an example of a Parallel SGD algorithm where gradients are encoded before
1992            allreduce, and then decoded after allreduce.
1993
1994            >>> # xdoctest: +SKIP('undefined name')
            >>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
            >>>     encoded_tensor = encode(bucket.buffer())  # encode gradients
            >>>     fut = torch.distributed.all_reduce(
            >>>         encoded_tensor, async_op=True
            >>>     ).get_future()
            >>>     # Define the then callback to decode; avoid shadowing the
            >>>     # external ``decode`` helper used inside the callback.
            >>>     def decode_callback(fut):
            >>>         decoded_tensor = decode(fut.value()[0])  # decode gradients
            >>>         return decoded_tensor
            >>>     return fut.then(decode_callback)
            >>> ddp.register_comm_hook(state=None, hook=encode_and_decode)
2004        """
2005        self._check_comm_hook(hook)
2006        assert self.logger is not None
2007        self.logger._set_comm_hook_name(hook.__qualname__)
2008        self._comm_hooks.append((hook, state))
2009        dist._register_comm_hook(self.reducer, state, hook)
2010
2011    def _register_builtin_comm_hook(self, comm_hook_type):
2012        r"""
2013        Register a built-in communication hook that specifies how DDP aggregates gradients across multiple workers.
2014
2015        The built-in hooks aim to provide efficient C++ implementations for certain hooks,
2016        which might not be as efficient if implemented in Python using a Python communication hook.
2017
2018        Args:
2019            comm_hook_type (dist.BuiltinCommHookType): type of communication hook, such as ALLREDUCE, FP16_COMPRESS, etc.
2020
2021        .. warning ::
2022            DDP communication hook can only be registered once and should be registered
2023            before calling backward.
2024
2025        Example::
            Below is an example of an FP16 compression where gradients are
2027            compressed into 16-bit floating-point numbers before allreduce, and
2028            then decompressed after allreduce.
2029
2030            >>> # xdoctest: +SKIP('undefined name')
2031            >>> ddp._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS)
2032
2033        """
2034        assert self.logger is not None
2035        self.logger._set_comm_hook_name(str(comm_hook_type))
2036        dist._register_builtin_comm_hook(self.reducer, comm_hook_type)
2037
2038    def _register_fused_optim(self, optim: Type, *args, optim_params=None, **kwargs):
2039        r"""
        Register an optimizer in DDP to optimize a parameter immediately after its gradient reduction.
2041
2042        Registers an optimizer with DDP such that the optimization for a
2043        parameter will run immediately when that parameter's gradient is
2044        finished with reduction, instead of waiting for all parameters'
2045        gradients to finish reduction. This can result in a training speedup
2046        depending on your workload since the optimizer can run while gradient
        reduction for other parameters is still ongoing. In addition, this has
2048        the potential to reduce peak memory consumption during training, as it
2049        only needs to load the per-parameter optimizer states of a single
2050        parameter at a time, instead of loading all per-parameter optimizer
2051        states at once.
2052
2053        Args:
            optim (Type): a ``torch.optim.Optimizer`` class to be registered
                as a fused optimizer.
            *args (Sequence[Any]): Arguments to forward to ``optim``.
            optim_params (Optional[Iterable[torch.Tensor]]): Set of parameters
                to optimize, similar to the ``params`` argument of traditional
                ``torch.optim`` optimizers. If this is omitted, all DDP model
                parameters will be optimized.
            **kwargs (Dict[str, Any]): Keyword arguments to forward to ``optim``.
2062
2063        .. warning ::
2064            _register_fused_optim should only be called once on a DDP instance,
2065            and registering multiple fused optimizers for the same DDP model
2066            is not currently supported. Please ping
2067            https://github.com/pytorch/pytorch/issues/71595 if this is necessary
2068            for your use case.
2069
2070        .. warning ::
2071            _register_fused_optim and register_comm_hook currently do not
2072            compose together, meaning that custom DDP communication hooks are
2073            not supported with overlapped optimizers. Please ping
2074            https://github.com/pytorch/pytorch/issues/71595 if this is necessary
2075            for your use case.
2076
2077        .. warning ::
2078            Gradient accumulation and DDP `no_sync` are currently not supported
2079            with overlapped optimizer. Please ping
2080            https://github.com/pytorch/pytorch/issues/71595 if this is necessary
2081            for your use case.
2082
2083        Example::
2084
2085            >>> # xdoctest: +SKIP("No rendezvous handler")
2086            >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
2087            >>> net = torch.nn.parallel.DistributedDataParallel(model, pg)
2088            >>> lr = 1e-2
2089            >>> betas = (0.9, 0.99)
2090            >>> eps = 1e-6
2091            >>> net._register_fused_optim(torch.optim.Adam, lr, betas=betas, eps=eps)
2092            >>> # Example with subset of parameters
2093            >>> params_to_opt = [list(net.parameters())[0]]
2094            >>> net._register_fused_optim(
2095            ...   torch.optim.Adam, lr, optim_params=params_to_opt,  betas=betas, eps=eps
2096            ... )
2097        """
2098        # Note: importing in function, otherwise this will cause a circular
2099        # import as optimizer_overlap module needs to import DistributedDataParallel.
2100        from torch.distributed.algorithms._optimizer_overlap import _as_overlapped_optim
2101
2102        overlapped_optim = _as_overlapped_optim(optim, optim_params, *args, **kwargs)
2103        try:
2104            overlapped_optim.register_ddp(self)
2105        except NotImplementedError as e:
2106            raise RuntimeError(
2107                f"{optim} does not support overlapped DDP. Please file an issue to PyTorch or the respective owner of {optim}."
2108            ) from e
2109
2110    def _distributed_broadcast_coalesced(
2111        self, tensors, buffer_size, authoritative_rank=0
2112    ):
2113        dist._broadcast_coalesced(
2114            self.process_group, tensors, buffer_size, authoritative_rank
2115        )
2116
2117    def _check_sync_bufs_post_fwd(self):
2118        return (
2119            self.will_sync_module_buffers()
2120            and hasattr(self, "buffer_hook")
2121            and self.buffer_hook.buffer_comm_hook_location
2122            == _BufferCommHookLocation.POST_FORWARD
2123        )
2124
2125    def _check_sync_bufs_pre_fwd(self):
2126        return self.will_sync_module_buffers() and (
2127            not hasattr(self, "buffer_hook")
2128            or self.buffer_hook.buffer_comm_hook_location
2129            == _BufferCommHookLocation.PRE_FORWARD
2130        )
2131
2132    def will_sync_module_buffers(self):
2133        return (
2134            self.require_forward_param_sync
2135            and self.broadcast_buffers
2136            and len(self.modules_buffers) > 0
2137        )
2138
2139    def _find_common_rank(self, input_rank, rank_cond):
2140        # -1 indicates that this rank is not under consideration to be the
2141        # common_rank
2142        rank_to_use = torch.tensor(
2143            [input_rank if rank_cond else -1],
2144            device=self.device,
2145        )
2146        dist.all_reduce(rank_to_use, op=ReduceOp.MAX, group=self.process_group)
2147        if rank_to_use.item() == -1:
2148            self._log_and_throw(
2149                ValueError,
2150                "BUG! Expected rank_cond to be true for at least one process."
2151                " This indicates a bug in PyTorch, please report an issue.",
2152            )
2153        return rank_to_use.item()
2154
2155    def _sync_buffers(self):
2156        with torch.no_grad():
            # Module buffer sync: synchronize buffers across processes.
2159            # If we are running DDP with the join manager, we have to agree
2160            # upon a rank to sync module buffers from, since rank 0 may
2161            # already have been joined and have stale module buffers.
2162            if self._join_config.enable:
2163                authoritative_rank = self._find_common_rank(
2164                    self._distributed_rank, True
2165                )
2166            else:
2167                # The process with rank 0 is considered the authoritative copy.
2168                authoritative_rank = 0
            # Update self.modules_buffers in case any buffers were
            # reassigned.
2171            self._assign_modules_buffers()
2172            self._sync_module_buffers(authoritative_rank)
2173
2174    def _sync_module_buffers(self, authoritative_rank):
2175        if not hasattr(self, "buffer_hook"):
2176            self._default_broadcast_coalesced(authoritative_rank=authoritative_rank)
2177        else:
2178            hook = self.buffer_hook.buffer_comm_hook
2179            state = self.buffer_hook.buffer_comm_hook_state
2180            futs = hook(state, self.named_module_buffers)
2181            if futs is not None:
2182                self.reducer._install_post_backward_futures(futs)
2183
2184    def _default_broadcast_coalesced(
2185        self, bufs=None, bucket_size=None, authoritative_rank=0
2186    ):
2187        """
        Broadcasts buffers from the authoritative rank (rank 0 by default) to the rest of the workers.

        If ``bufs`` or ``bucket_size`` is None, the default values ``self.modules_buffers``
        and ``self.broadcast_bucket_size`` are used instead.
2192        """
2193        if bufs is None:
2194            bufs = self.modules_buffers
2195        if bucket_size is None:
2196            bucket_size = self.broadcast_bucket_size
2197
2198        self._distributed_broadcast_coalesced(bufs, bucket_size, authoritative_rank)
2199
2200    def _passing_sync_batchnorm_handle(self, module):
2201        for layer in module.modules():
2202            if isinstance(layer, torch.nn.modules.SyncBatchNorm):
2203                if self.device_type == "cpu":
2204                    self._log_and_throw(
2205                        ValueError,
2206                        "SyncBatchNorm layers only work with GPU modules",
2207                    )
2208
2209    def _check_comm_hook(self, hook):
2210        if not callable(hook):
2211            self._log_and_throw(TypeError, "Communication hook must be callable.")
2212
2213        sig = inspect.signature(hook)
2214        if (
2215            sig.parameters["bucket"].annotation != inspect._empty
2216            and sig.parameters["bucket"].annotation != dist.GradBucket
2217        ):
2218            self._log_and_throw(
2219                ValueError,
2220                "Communication hook: bucket annotation should be dist.GradBucket.",
2221            )
2222
2223        if (
2224            sig.return_annotation != inspect._empty
2225            and sig.return_annotation != torch.futures.Future[torch.Tensor]
2226        ):
2227            self._log_and_throw(
2228                ValueError,
2229                "Communication hook: return annotation should be torch.futures.Future[torch.Tensor].",
2230            )
2231
2232        if hook.__name__ in [
2233            "bf16_compress_hook",
2234            "bf16_compress_wrapper_hook",
2235        ] and (
2236            (torch.version.cuda is None and torch.version.hip is None)
2237            or (
2238                torch.version.cuda is not None
2239                and int(torch.version.cuda.split(".")[0]) < 11
2240            )
2241            or not dist.is_available()
2242            or not dist.is_nccl_available()
2243            or torch.cuda.nccl.version() < (2, 10)
2244        ):
2245            self._log_and_throw(
2246                TypeError,
                "BF16 all reduce communication hook requires CUDA 11+ and NCCL 2.10+.",
2248            )
2249
2250    @property
2251    def _distributed_rank(self):
2252        return dist.get_rank(self.process_group)
2253
2254    @staticmethod
2255    def _get_data_parallel_params(module, named_params=False):
2256        """Return a generator of parameters managed by a given DDP unit."""
2257        for param in (
2258            module.parameters() if not named_params else module.named_parameters()
2259        ):
2260            if not hasattr(param, "_ddp_ignored"):
2261                yield param
2262
2263    @staticmethod
2264    def _set_params_and_buffers_to_ignore_for_model(
2265        module, params_and_buffers_to_ignore
2266    ):
2267        """
2268        Set parameters and buffers to be ignored by DDP.
2269
        Expected format for parameters is the fully qualified name: {module_name}.{param_name},
        and similarly, {module_name}.{buffer_name} for buffers.

        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> params_to_ignore = []
            >>> # NB: model here is a vanilla PyTorch module, not yet wrapped with DDP.
            >>> for module_name, module in model.named_modules():
            >>>     for param_name, param in module.named_parameters(recurse=False):
            >>>         if should_ignore(param):
            >>>             # Create expected format
            >>>             fqn = f"{module_name}.{param_name}"
            >>>             params_to_ignore.append(fqn)
            >>> torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
            >>>     model,
            >>>     params_to_ignore,
            >>> )
2284        """
2285        # This is a workaround to set parameters and buffers DDP should ignore
2286        # during synchronization. It will be removed when the API is finalized
2287        # as part of addressing https://github.com/pytorch/pytorch/issues/43690.
2288        module._ddp_params_and_buffers_to_ignore = params_and_buffers_to_ignore
2289        for name, param in module.named_parameters():
2290            if name in params_and_buffers_to_ignore:
2291                param._ddp_ignored = True
2292        for name, buffer in module.named_buffers():
2293            if name in params_and_buffers_to_ignore:
2294                buffer._ddp_ignored = True
2295
2296    def _get_ddp_logging_data(self):
2297        r"""
2298        Return a dictionary of logging data for debugging and analysis.
2299
        This interface can be called after DistributedDataParallel() is
        constructed. It returns a dictionary of logging data, which can help
        with debugging and analysis. The logging data includes the
        DistributedDataParallel constructor input parameters, some internal
        states of DistributedDataParallel, and performance metrics. Simply
        print the dictionary to see what these metrics are.
2306        This is a prototype interface and subject to change in the future.
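
        Example (illustrative sketch; assumes ``ddp`` is an already-constructed
        ``DistributedDataParallel`` instance)::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> ddp_logging_data = ddp._get_ddp_logging_data()
            >>> # The result is a plain dict of string- and int-valued fields.
            >>> for key, value in sorted(ddp_logging_data.items()):
            >>>     print(f"{key}: {value}")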
2307        """
2308        assert self.logger is not None
2309        ddp_logging_data = self.logger._get_ddp_logging_data()
2310        return {**ddp_logging_data.strs_map, **ddp_logging_data.ints_map}
2311
2312    def _set_ddp_runtime_logging_sample_rate(self, sample_rate):
2313        r"""
2314        Set sample_rate of collecting runtime stats.
2315
        This interface allows users to set the sample_rate for collecting
        runtime stats. Runtime stats are recorded for the first 10 iterations;
        after that, they are recorded once every ``sample_rate`` training
        iterations. By default, runtime stats are recorded once every
        ``kDDPRuntimeLoggingSampleRate=100`` training iterations after the
        first 10 iterations.
2323        This is a prototype interface and subject to change in the future.
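
        Example (illustrative sketch; assumes ``ddp`` is an already-constructed
        ``DistributedDataParallel`` instance)::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> # After the first 10 iterations, record runtime stats once every
            >>> # 20 training iterations instead of the default 100.
            >>> ddp._set_ddp_runtime_logging_sample_rate(20)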
2324        """
2325        if sample_rate < 1:
2326            self._log_and_throw(
2327                ValueError,
                "DDP runtime logging sample rate should be equal to or greater than 1",
2329            )
2330        self.reducer._set_ddp_runtime_logging_sample_rate(sample_rate)
2331
2332    def _set_static_graph(self):
2333        """
2334        Set static graph for DDP.
2335
2336        It is recommended to set static graph in the DDP constructor, which will
2337        call this private API internally.
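
        Example (illustrative sketch of the recommended constructor-based way to
        enable static graph; ``model`` and ``pg`` are assumed to exist)::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> ddp = torch.nn.parallel.DistributedDataParallel(
            ...     model, process_group=pg, static_graph=True
            ... )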
2338        """
2339        # If self.static_graph has been set, no need to set it again
2340        if self.static_graph:
2341            warnings.warn(
2342                "You've set static_graph to be True, no need to set it again."
2343            )
2344            return
2345        self.static_graph = True
2346        self._static_graph_delay_allreduce_enqueued = False
2347        self.reducer._set_static_graph()
2348        assert self.logger is not None
2349        self.logger._set_static_graph()
2350        if self.find_unused_parameters:
2351            warnings.warn(
                "You passed find_unused_parameters=true to DistributedDataParallel. "
                "`_set_static_graph` will detect unused parameters automatically, so "
                "you do not need to set find_unused_parameters=true; just be sure these "
                "unused parameters will not change during the training loop while "
                "`_set_static_graph` is in use."
2357            )
2358
2359    def _remove_autograd_hooks(self):
2360        """Remove autograd hooks registered by the reducer on the model parameters."""
2361        self.reducer._remove_autograd_hooks()
2362
2363    def _check_reducer_finalized(self):
2364        """
2365        Check if the reducer has processed all buckets and finalized the backward appropriately.
2366
        It is useful to call this method after calling .backward() in your training loop
        in order to avoid subsequent hard-to-debug errors down the road caused by the
        reducer not finalizing the backward pass.
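
        Example (illustrative sketch; assumes ``ddp`` and a computed ``loss``
        already exist)::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> loss.backward()
            >>> # Surfaces an error early if the reducer did not finalize backward.
            >>> ddp._check_reducer_finalized()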
2370        """
2371        self.reducer._check_reducer_finalized()
2372
2373    def _set_sparse_metadata(self, global_unique_ids):
2374        self.reducer._set_sparse_metadata(global_unique_ids)
2375
2376    def _update_process_group(self, new_process_group):
2377        """
2378        Dynamically updates the process group for DDP so that we can shrink/expand DDP
2379        world size without having to reinitialize DDP.
2380
        NOTE: If you are using custom communication hooks via ``register_comm_hook``,
        you need to update the process groups for those hooks separately.
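
        Example (illustrative sketch; assumes an already-constructed ``ddp``
        instance and a hypothetical ``remaining_ranks`` list describing the
        shrunken world)::

            >>> # xdoctest: +SKIP("Distributed")
            >>> new_pg = torch.distributed.new_group(ranks=remaining_ranks)
            >>> ddp._update_process_group(new_pg)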
2383        """
2384        # Force a rebuild of buckets for a new process group. This ensures all ranks
2385        # are synchronized in terms of when they will rebuild buckets and also
2386        # re-evaluates previous assumptions of buckets given the world size might have
2387        # changed.
2388        self._has_rebuilt_buckets = False
2389        self.reducer._reset_state()
2390
2391        if not _rank_not_in_group(new_process_group):
2392            self.process_group = new_process_group
2393            self.reducer._update_process_group(new_process_group)
2394
2395    def _set_ddp_sink_clone(self, val: bool):
2396        """
        Sets whether DDPSink should clone the output tensors.

        The default is True, since if the loss is modified in place we otherwise
        run into an error about a view being modified in-place.

        However, cloning the tensors can add a significant memory and
        performance hit if the number and size of the tensors are large. As
        a result, this can be set to False if you are not modifying the
        loss in place.
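
        Example (illustrative sketch; assumes ``ddp`` is an already-constructed
        ``DistributedDataParallel`` instance and that the loss is not modified
        in place)::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> ddp._set_ddp_sink_clone(False)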
2405        """
2406        self._ddp_sink_clone = val
2407