xref: /aosp_15_r20/external/pytorch/torch/distributed/distributed_c10d.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""Distributed Collective Communication (c10d)."""
3
4import collections.abc
5import contextlib
6import hashlib
7import io
8import itertools
9import logging
10import os
11import pickle
12import sys
13import time
14import warnings
15from collections import namedtuple
16from datetime import timedelta
17from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
18from typing_extensions import deprecated
19
20import torch
21from torch._C import _DistStoreError as DistStoreError
22from torch._C._distributed_c10d import (
23    _DistributedBackendOptions,
24    _register_process_group,
25    _resolve_process_group,
26    _unregister_all_process_groups,
27    _unregister_process_group,
28    AllgatherOptions,
29    AllreduceCoalescedOptions,
30    AllreduceOptions,
31    AllToAllOptions,
32    BarrierOptions,
33    BroadcastOptions,
34    DebugLevel,
35    GatherOptions,
36    get_debug_level,
37    PrefixStore,
38    ProcessGroup,
39    ReduceOp,
40    ReduceOptions,
41    ReduceScatterOptions,
42    ScatterOptions,
43    Store,
44    Work,
45)
46from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs
47from torch.utils._typing_utils import not_none
48
49from .c10d_logger import _exception_logger, _time_logger
50from .constants import default_pg_nccl_timeout, default_pg_timeout
51from .rendezvous import register_rendezvous_handler, rendezvous  # noqa: F401
52
53
54__all__ = [
55    "Backend",
56    "BackendConfig",
57    "GroupMember",
58    "P2POp",
59    "all_gather",
60    "all_gather_coalesced",
61    "all_gather_object",
62    "all_reduce",
63    "all_reduce_coalesced",
64    "all_to_all",
65    "all_to_all_single",
66    "barrier",
67    "batch_isend_irecv",
68    "broadcast",
69    "send_object_list",
70    "recv_object_list",
71    "broadcast_object_list",
72    "destroy_process_group",
73    "gather",
74    "gather_object",
75    "get_backend_config",
76    "get_backend",
77    "get_rank",
78    "get_world_size",
79    "get_pg_count",
80    "group",
81    "init_process_group",
82    "irecv",
83    "is_gloo_available",
84    "is_initialized",
85    "is_mpi_available",
86    "is_backend_available",
87    "is_nccl_available",
88    "is_torchelastic_launched",
89    "is_ucc_available",
90    "isend",
91    "monitored_barrier",
92    "new_group",
93    "new_subgroups",
94    "new_subgroups_by_enumeration",
95    "recv",
96    "reduce",
97    "reduce_scatter",
98    "scatter",
99    "scatter_object_list",
100    "send",
101    "supports_complex",
102    "AllreduceCoalescedOptions",
103    "AllreduceOptions",
104    "AllToAllOptions",
105    "BarrierOptions",
106    "BroadcastOptions",
107    "GatherOptions",
108    "PrefixStore",
109    "ProcessGroup",
110    "ReduceOp",
111    "ReduceOptions",
112    "ReduceScatterOptions",
113    "ScatterOptions",
114    "Store",
115    "DebugLevel",
116    "get_debug_level",
117    "Work",
118    "default_pg_timeout",
119    "get_group_rank",
120    "get_global_rank",
121    "get_process_group_ranks",
122    "reduce_op",
123    "all_gather_into_tensor",
124    "reduce_scatter_tensor",
125    "get_node_local_rank",
126    "split_group",
127]
128
129_MPI_AVAILABLE = True
130_NCCL_AVAILABLE = True
131_GLOO_AVAILABLE = True
132_UCC_AVAILABLE = True
133
134_pickler = pickle.Pickler
135_unpickler = pickle.Unpickler
136
137
138# Change __module__ of all imported types from torch._C._distributed_c10d that are public
139def _export_c_types() -> None:
140    _public_types_to_change_module = [
141        AllreduceCoalescedOptions,
142        AllreduceOptions,
143        AllToAllOptions,
144        BarrierOptions,
145        BroadcastOptions,
146        GatherOptions,
147        PrefixStore,
148        ProcessGroup,
149        ReduceOp,
150        ReduceOptions,
151        ReduceScatterOptions,
152        ScatterOptions,
153        Store,
154        DebugLevel,
155        get_debug_level,
156        Work,
157    ]
158    for type in _public_types_to_change_module:
159        type.__module__ = "torch.distributed.distributed_c10d"
160
161
162_export_c_types()
163
164try:
165    from torch._C._distributed_c10d import ProcessGroupMPI
166
167    ProcessGroupMPI.__module__ = "torch.distributed.distributed_c10d"
168    __all__ += ["ProcessGroupMPI"]
169except ImportError:
170    _MPI_AVAILABLE = False
171
172try:
173    from torch._C._distributed_c10d import ProcessGroupNCCL
174
175    ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d"
176    __all__ += ["ProcessGroupNCCL"]
177except ImportError:
178    _NCCL_AVAILABLE = False
179
180try:
181    from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo
182
183    ProcessGroupGloo.__module__ = "torch.distributed.distributed_c10d"
184    __all__ += ["ProcessGroupGloo"]
185except ImportError:
186    _GLOO_AVAILABLE = False
187
188try:
189    from torch._C._distributed_c10d import ProcessGroupUCC
190
191    ProcessGroupUCC.__module__ = "torch.distributed.distributed_c10d"
192    __all__ += ["ProcessGroupUCC"]
193except ImportError:
194    _UCC_AVAILABLE = False
195
196logger = logging.getLogger(__name__)
197
198PG_WRAPPER_STORE_PREFIX = "pg_wrapper"
199
200
201# Some reduce ops are not supported by complex numbers and will result in an error.
202# We currently provide complex support to the distributed API by viewing
203# complex tensors as real (torch.view_as_real), meaning that calling
204# these unsupported ops will return garbage values rather than error out.
205# (e.g. max(2+3i, 3+2i) = 3+3i)
206# We'd like calls to unsupported ops to error out accordingly,
207# rather than returning garbage values.
208def supports_complex(reduceOp: ReduceOp) -> bool:
209    """Return true if reduce ops is supported. False otherwise."""
210    denyList = [
211        ReduceOp.MAX,
212        ReduceOp.MIN,
213        ReduceOp.PRODUCT,
214        ReduceOp.BAND,
215        ReduceOp.BOR,
216        ReduceOp.BXOR,
217    ]
218    return reduceOp not in denyList
219
220
221class Backend(str):
222    """
223    An enum-like class for backends.
224
225    Available backends: GLOO, NCCL, UCC, MPI, and other registered backends.
226
227    The values of this class are lowercase strings, e.g., ``"gloo"``. They can
228    be accessed as attributes, e.g., ``Backend.NCCL``.
229
230    This class can be directly called to parse the string, e.g.,
231    ``Backend(backend_str)`` will check if ``backend_str`` is valid, and
232    return the parsed lowercase string if so. It also accepts uppercase strings,
233    e.g., ``Backend("GLOO")`` returns ``"gloo"``.
234
235    .. note:: The entry ``Backend.UNDEFINED`` is present but only used as
236              initial value of some fields. Users should neither use it directly
237              nor assume its existence.
238    """
239
240    UNDEFINED = "undefined"
241    GLOO = "gloo"
242    NCCL = "nccl"
243    UCC = "ucc"
244    MPI = "mpi"
245
246    _BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"])
247
248    _plugins: Dict[str, _BackendPlugin] = {}
249
250    backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI]
251
252    default_device_backend_map: Dict[str, str] = {
253        "cpu": GLOO,
254        "cuda": NCCL,
255    }
256
257    backend_capability: Dict[str, List[str]] = {
258        GLOO: ["cpu", "cuda"],
259        NCCL: ["cuda"],
260        UCC: ["cpu", "cuda"],
261        MPI: ["cpu", "cuda"],
262    }
263
264    backend_type_map: Dict[str, ProcessGroup.BackendType] = {
265        UNDEFINED: ProcessGroup.BackendType.UNDEFINED,
266        GLOO: ProcessGroup.BackendType.GLOO,
267        NCCL: ProcessGroup.BackendType.NCCL,
268        UCC: ProcessGroup.BackendType.UCC,
269    }
270
271    def __new__(cls, name: str):
272        """Create and return a new instance of the class."""
273        if not isinstance(name, str):
274            raise ValueError("Backend constructor parameter must be string-ish")
275        value = getattr(Backend, name.upper(), Backend.UNDEFINED)
276
277        if value == Backend.UNDEFINED:
278            value = name.lower()
279        return value
280
281    @classmethod
282    def register_backend(
283        cls,
284        name,
285        func,
286        extended_api=False,
287        devices: Optional[Union[str, List[str]]] = None,
288    ) -> None:
289        """
290        Register a new backend with the given name and instantiating function.
291
292        This class method is used by 3rd party ``ProcessGroup`` extension to
293        register new backends.
294
295        Args:
296            name (str): Backend name of the ``ProcessGroup`` extension. It
297                        should match the one in ``init_process_group()``.
298            func (function): Function handler that instantiates the backend.
299                             The function should be implemented in the backend
300                             extension and takes four arguments, including
301                             ``store``, ``rank``, ``world_size``, and ``timeout``.
302            extended_api (bool, optional): Whether the backend supports extended argument structure.
303                                           Default: ``False``. If set to ``True``, the backend
304                                           will get an instance of ``c10d::DistributedBackendOptions``, and
305                                           a process group options object as defined by the backend implementation.
306            device (str or list of str, optional): device type this backend
307                            supports, e.g. "cpu", "cuda", etc. If `None`,
308                            assuming both "cpu" and "cuda"
309
310        .. note:: This support of 3rd party backend is experimental and subject to change.
311
312        """
313        # Allow UCC plugin if Pytorch is not built with native support.
314        # TODO: remove this exception once UCC plugin is fully deprecated.
315        if name != Backend.UCC or (name == Backend.UCC and is_ucc_available()):
316            assert not hasattr(
317                Backend, name.upper()
318            ), f"{name.upper()} c10d backend already exist"
319        assert (
320            name.upper() not in Backend._plugins
321        ), f"{name.upper()} c10d backend creator function already exist"
322
323        setattr(Backend, name.upper(), name.lower())
324        Backend.backend_list.append(name.lower())
325        if devices is not None:
326            for device in devices:
327                if device != "cpu" and device != "cuda":
328                    Backend.default_device_backend_map[device] = name.lower()
329        Backend.backend_type_map[name.lower()] = ProcessGroup.BackendType.CUSTOM
330
331        # Update device capability matrix in Backend class
332        if devices is None:
333            # This is more of a backward support for groups like `threaded`:
334            # assume default devices "cpu" and "cuda", but warn
335            warnings.warn(
336                f"Device capability of {name} unspecified, assuming `cpu` and "
337                "`cuda`. Please specify it via the `devices` argument of "
338                "`register_backend`."
339            )
340            Backend.backend_capability[name.lower()] = ["cpu", "cuda"]
341        elif isinstance(devices, str):
342            # Single device string specified. Simply convert to list.
343            Backend.backend_capability[name.lower()] = [devices]
344        else:
345            Backend.backend_capability[name.lower()] = devices
346
347        Backend._plugins[name.upper()] = Backend._BackendPlugin(func, extended_api)
348
349
350class BackendConfig:
351    """Backend configuration class."""
352
353    def __init__(self, backend: Backend):
354        """Init."""
355        self.device_backend_map: Dict[str, Backend] = {}
356        backend = str(backend)
357
358        if backend == Backend.UNDEFINED:
359            # default config when backend is not specified
360            # supported since PyTorch 2.0
361            for device, default_backend in Backend.default_device_backend_map.items():
362                if is_backend_available(default_backend):
363                    if (
364                        default_backend == Backend.NCCL
365                        and not torch.cuda.is_available()
366                    ):
367                        continue
368                    self.device_backend_map[device] = Backend(default_backend)
369        elif backend.lower() in Backend.backend_list:
370            # Cases for when backend is a single string (without device types)
371            # e.g. "nccl", "gloo", "ucc", "mpi"
372            supported_devices = Backend.backend_capability[backend.lower()]
373            backend_val = Backend(backend)
374            self.device_backend_map = dict.fromkeys(supported_devices, backend_val)
375        elif ":" in backend.lower():
376            # Backend specified in "device:backend" format
377            # make sure the backend string is in the correct format
378            # "{device_type1}:{backend1},{device_type2}:{backend2}"
379            # e.g. "cpu:gloo,cuda:nccl"
380            backend_str_error_message = f"""The custom backend string argument is invalid: {backend}.
381                Custom backend string is an experimental feature where the backend string must be in the format:
382                "<device_type1>:<backend1>,<device_type2>:<backend2>...". e.g. 'cpu:gloo,cuda:nccl'"""
383
384            # parse the backend string and populate the device_backend_map
385            for device_backend_pair_str in backend.lower().split(","):
386                device_backend_pair = device_backend_pair_str.split(":")
387                if len(device_backend_pair) != 2:
388                    raise ValueError(
389                        f"Invalid device:backend pairing: \
390                                     {device_backend_pair_str}. {backend_str_error_message}"
391                    )
392                device, backend = device_backend_pair
393                if device in self.device_backend_map:
394                    raise ValueError(
395                        f"Duplicate device type {device} \
396                                     in backend string: {backend}. {backend_str_error_message}"
397                    )
398                self.device_backend_map[device] = Backend(backend)
399        else:
400            # User specified a single backend name whose device capability is
401            # unknown, assuming it can support the default devices of PyTorch
402            # (cpu and cuda)
403            warnings.warn(
404                f"Device capability of {backend} unknown, assuming `cpu` and "
405                "`cuda`. You can specify it in `device:backend` format in "
406                "`init_process_group` call."
407            )
408            backend_val = Backend(backend)
409            self.device_backend_map = {
410                "cpu": backend_val,
411                "cuda": backend_val,
412                "xpu": backend_val,
413            }
414
415        logger.info("Using backend config: %s", self.device_backend_map)
416
417    def __repr__(self):
418        """Return all the device:backend pairs separated by commas."""
419        return ",".join(
420            f"{device}:{backend}" for device, backend in self.device_backend_map.items()
421        )
422
423    def get_device_backend_map(self) -> Dict[str, Backend]:
424        """Return backend map of the device."""
425        return self.device_backend_map
426
427
428class _reduce_op:
429    r"""
430    Deprecated enum-like class.
431
432    For reduction operations: ``SUM``, ``PRODUCT``, ``MIN``, and ``MAX``.
433
434    :class:`~torch.distributed.ReduceOp` is recommended to use instead.
435    """
436
437    def __init__(self) -> None:
438        # __members__ is a dict storing key-value pairs for enum classes
439        for k, v in ReduceOp.RedOpType.__members__.items():
440            setattr(self, k, v)
441        self.__members__ = ReduceOp.RedOpType.__members__
442
443    @deprecated(
444        "`torch.distributed.reduce_op` is deprecated, "
445        "please use `torch.distributed.ReduceOp` instead",
446        category=FutureWarning,
447    )
448    def __getattribute__(self, key):
449        return object.__getattribute__(self, key)
450
451
452reduce_op = _reduce_op()
453
454
455class P2POp:
456    """
457    A class to build point-to-point operations for ``batch_isend_irecv``.
458
459    This class builds the type of P2P operation, communication buffer, peer rank,
460    Process Group, and tag. Instances of this class will be passed to
461    ``batch_isend_irecv`` for point-to-point communications.
462
463    Args:
464        op (Callable): A function to send data to or receive data from a peer process.
465            The type of ``op`` is either ``torch.distributed.isend`` or
466            ``torch.distributed.irecv``.
467        tensor (Tensor): Tensor to send or receive.
468        peer (int): Destination or source rank.
469        group (ProcessGroup, optional): The process group to work on. If None,
470            the default process group will be used.
471        tag (int, optional): Tag to match send with recv.
472    """
473
474    def __init__(
475        self,
476        op: Callable,
477        tensor: torch.Tensor,
478        peer: int,
479        group: Optional[ProcessGroup] = None,
480        tag: int = 0,
481    ):
482        """Init."""
483        self.op = op
484        self.tensor = tensor
485        self.peer = peer
486        self.group = group
487        self.tag = tag
488
489    def __new__(
490        cls,
491        op: Callable,
492        tensor: torch.Tensor,
493        peer: int,
494        group: Optional[ProcessGroup] = None,
495        tag: int = 0,
496    ):
497        """Create and return a new instance of the class."""
498        _check_op(op)
499        _check_single_tensor(tensor, "tensor")
500        return object.__new__(cls)
501
502    def __repr__(self):
503        my_group_rank = get_rank(self.group)
504        peer_group_rank = (
505            get_group_rank(self.group, self.peer) if self.group else self.peer
506        )
507        op_name = self.op.__name__
508        group_name = self.group.group_name if self.group else "default_pg"
509        if "send" in op_name:
510            s = my_group_rank
511            d = peer_group_rank
512        elif "recv" in op_name:
513            s = peer_group_rank
514            d = my_group_rank
515        else:
516            return super().__repr__()
517
518        return f"P2POp({op_name} pg={group_name}, s={s}, d={d},  {self.tensor.shape}, {self.tensor.dtype})"
519
520
521class _CollOp:
522    """
523    A class to capture collective operations.
524
525    Args:
526        op (Callable): A collective function, e.g. ``torch.distributed.all_reduce``.
527        tensor (Tensor): Tensor to operate on.
528        dst_tensor (Tensor, optional): Provided when source and destinaton tensors are not the same.
529        redop (ReduceOp, optional): reduce operation.
530        root (int, optional): root of broadcast or reduce.
531    """
532
533    def __init__(
534        self,
535        op: Callable,
536        tensor: torch.Tensor,
537        dst_tensor: Optional[torch.Tensor] = None,
538        redop: Optional[ReduceOp] = None,
539        root: Optional[int] = None,
540    ):
541        self.op = op
542        self.tensor = tensor
543        self.dst_tensor = dst_tensor
544        self.redop = redop
545        self.root = root
546
547
548# DO NOT USE THESE FIELDS DIRECTLY.
549# Use them through the _world object to make sure the _world override mechanism
550_pg_map: Dict[ProcessGroup, Tuple[str, Store]] = {}
551_pg_names: Dict[ProcessGroup, str] = {}
552_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}
553# For a pg, it is a map from ProcessGroup to BackendConfig
554_pg_backend_config: Dict[ProcessGroup, str] = {}
555_group_count = 0
556_tags_to_pg: Dict[str, List[ProcessGroup]] = {}
557_pg_to_tag: Dict[ProcessGroup, str] = {}
558_backend: Optional[str] = None
559
560
561class _World:
562    """
563    Container class for c10d process group state.
564
565    This is used during registration and lookup of PG state.
566
567    .. warning:: This is an experimental API intended to expose the inner workings
568       of c10d and is subject to change..
569    """
570
571    def __init__(self) -> None:
572        self._default_pg = None
573        self._pg_coalesce_state: Dict[ProcessGroup, List[_CollOp]] = {}
574        self._pg_default_device: Dict[ProcessGroup, torch.device] = {}
575
576    @property
577    def default_pg(self) -> Optional[ProcessGroup]:
578        """
579        Process group that includes all ranks of the cluster.
580
581        This default ProcessGroup is used by c10d APIs when a ProcessGroup is needed
582        but None is provided.
583        """
584        return self._default_pg
585
586    @default_pg.setter
587    def default_pg(self, value) -> None:
588        self._default_pg = value
589
590    @property
591    def pg_map(self) -> Dict[ProcessGroup, Tuple[str, Store]]:
592        """
593        Provide Mapping from ProcessGroup to backend name and store.
594
595        For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store)
596        For MPI pg, it is a map from ProcessGroup to (Backend, None)
597
598        TODO don't expose the map, expose fine grained ops
599        """
600        global _pg_map
601        return _pg_map
602
603    @property
604    def pg_names(self) -> Dict[ProcessGroup, str]:
605        """
606        Process group's names, map from ProcessGroup to str.
607
608        TODO don't expose the map, expose fine grained ops
609        """
610        global _pg_names
611        return _pg_names
612
613    @property
614    def pg_group_ranks(self) -> Dict[ProcessGroup, Dict[int, int]]:
615        """
616        Process group's global rank to local rank mapping.
617
618        TODO don't expose the map, expose fine grained ops
619        """
620        global _pg_group_ranks
621        return _pg_group_ranks
622
623    @property
624    def pg_backend_config(self) -> Dict[ProcessGroup, str]:
625        """
626        Process group's backend config.
627
628        TODO don't expose the map, expose fine grained ops
629        """
630        global _pg_backend_config
631        return _pg_backend_config
632
633    @property
634    def group_count(self) -> int:
635        """
636        Process group count for default naming.
637
638        TODO don't expose group_count, use something else instead
639        """
640        global _group_count
641        return _group_count
642
643    @group_count.setter
644    def group_count(self, value: int) -> None:
645        """Use to compute the name of ProcessGroups when using global synchronization."""
646        global _group_count
647        _group_count = value
648
649    @property
650    def tags_to_pg(self) -> Dict[str, List[ProcessGroup]]:
651        global _tags_to_pg
652        return _tags_to_pg
653
654    @property
655    def pg_to_tag(self) -> Dict[ProcessGroup, str]:
656        global _pg_to_tag
657        return _pg_to_tag
658
659    @property
660    def pg_coalesce_state(self) -> Dict[ProcessGroup, List[_CollOp]]:
661        return self._pg_coalesce_state
662
663    @property
664    def pg_default_device(self) -> Dict[ProcessGroup, torch.device]:
665        return self._pg_default_device
666
667    @property
668    def pg_config_info(self) -> List[Dict[str, Any]]:
669        """
670        Return a list of dict with process groups and backends.
671
672        Along with their unique IDs and configurations (types and ranks).
673        """
674        config_info: List[Dict[str, Any]] = []
675        default_pg_size = _get_group_size(None)
676        for pg in self.pg_map.keys():
677            ranks = self.pg_group_ranks[pg]
678            config_info.append(
679                {
680                    "pg_name": self.pg_names[pg],
681                    "pg_desc": pg.group_desc,
682                    "backend_config": self.pg_backend_config[pg],
683                    "ranks": list(ranks.keys())
684                    if len(ranks) != default_pg_size
685                    else [],  # 'ranks' is an empty list when all ranks are involved in a pg
686                    "group_size": len(ranks),
687                    "group_count": self.group_count,
688                }
689            )
690        return config_info
691
692
693_world = _World()
694"""Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""
695
696
697class _WorldMeta(type):
698    """
699    Meta class of ``group`` and ``GroupMember``.
700
701    Allows them to have the class property ``WORLD``.
702    """
703
704    # Points to the default PG once initialized.
705    @property
706    def WORLD(cls) -> Optional[ProcessGroup]:
707        return _world.default_pg
708
709    @WORLD.setter
710    def WORLD(cls, pg: Optional[ProcessGroup]):
711        _world.default_pg = pg
712
713
714class group(metaclass=_WorldMeta):
715    """Group class. Placeholder."""
716
717
718class GroupMember(metaclass=_WorldMeta):
719    """Group member class."""
720
721    NON_GROUP_MEMBER = -100
722
723
724def _get_default_timeout(backend: Backend) -> timedelta:
725    # see note on nccl vs other backend timeout (constants.py)
726    if backend == Backend.NCCL:
727        if not isinstance(default_pg_nccl_timeout, timedelta):
728            # TODO moco benchmark on CPU initializes pgnccl backend today, triggered this assert in CI before it was
729            # changed to be a warning.  We should fix the moco model.
730            warnings.warn(
731                "Attempted to get default timeout for nccl backend, but NCCL support is not compiled"
732            )
733            return default_pg_timeout
734        return default_pg_nccl_timeout
735    else:
736        return default_pg_timeout
737
738
739def _check_valid_timeout(timeout: Any) -> None:
740    if not isinstance(timeout, timedelta):
741        raise TypeError(
742            f"Expected timeout argument to be of type datetime.timedelta, got {timeout}"
743        )
744
745
746# Default process group state
747_default_pg_init_method: Optional[str] = None
748
749STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key"
750
751
752def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device:
753    """
754    Return the device to use with ``group`` for control flow usage (object collectives, barrier).
755
756    There are selection rules:
757        1. If user specifies exactly one backend in ``init_process_group`` call:
758            use that backend
759        2. Else if user specifies multiple "device:backend" pairs in init_process_group:
760            If "cpu" is among those pairs, use "cpu" (because the object is in cpu memory);
761            Otherwise, use the first backend (sort of a random pick).
762
763    Args:
764        group (ProcessGroup, optional): The process group to work on. If None,
765            the default process group will be used.
766
767    Returns:
768        torch.device: The device to use with ``group``.
769
770    """
771    group = group or _get_default_group()
772    if group in _world.pg_default_device:
773        # Previously searched and cached; just return
774        return _world.pg_default_device[group]
775
776    if not isinstance(group, ProcessGroup):
777        # Provide backward compatibility to cases where `group` passed in is
778        # actually a Backend (like `ProcessGroupGloo`) rather than a
779        # `ProcessGroup` in PT 2.0 sense
780        warnings.warn(
781            f"You are using a Backend {type(group)} as a ProcessGroup. "
782            "This usage is deprecated since PyTorch 2.0. Please use a public API "
783            "of PyTorch Distributed instead.",
784            FutureWarning,
785            stacklevel=3,
786        )
787        # Most users create Gloo with private API for object collectives
788        _world.pg_default_device[group] = torch.device("cpu")
789        return _world.pg_default_device[group]
790
791    """
792    ``group._device_types`` is a property pybind that returns the devices
793    ("cpu", "cuda", etc) supported by ``group``. Can be multiple if the
794    ``group`` supports multiple devices.
795    """
796    devices = group._device_types
797
798    if len(devices) == 1:
799        # User fixed exactly one backend in `init_process_group`
800        _world.pg_default_device[group] = devices[0]
801    elif len(devices) == 0:
802        # No backend has been registered with this PG (maybe because no
803        # collective has been run?) We pick cpu as the default and hopefully
804        # this would lazily init Gloo or other available cpu backend.
805        _world.pg_default_device[group] = torch.device("cpu")
806    elif torch.device("cpu") in devices:
807        # There are multiple backends in this PG and cpu is among them.
808        # cpu is preferred as the object is in cpu memory. No need for device
809        # copy.
810        _world.pg_default_device[group] = torch.device("cpu")
811    else:
812        # No cpu in the backend list. Randomly pick the first backend
813        _world.pg_default_device[group] = devices[0]
814
815    logger.info(
816        "Using device %s for object " "collectives.", _world.pg_default_device[group]
817    )
818    return _world.pg_default_device[group]
819
820
821@_time_logger
822def _store_based_barrier(
823    rank,
824    store,
825    group_name,
826    rendezvous_count,
827    timeout,
828    logging_interval=timedelta(seconds=10),
829) -> None:
830    """
831    Store based barrier for synchronizing processes.
832
833    Barrier based on store which is used for synchronizing processes after
834    ``init_process_group`` or ``new_group``. Intended to be used only with
835    those two methods and is not a generic alternative to ``barrier()``.
836    """
837    store_key = f"{STORE_BASED_BARRIER_PREFIX}:{group_name}"
838    store.add(store_key, 1)
839    logger.debug("Added key: %s to store for rank: %s", store_key, rank)
840
841    # Now wait for all workers to check in with the store.
842    world_size = rendezvous_count
843    worker_count = store.add(store_key, 0)
844
845    last_worker_key = f"{store_key}:last_worker"
846    if worker_count == world_size:
847        store.set(last_worker_key, "1")
848
849    # adjust the timeout to be at least 10secs + 1sec per thousand ranks to reduce the odds of timeout
850    # this value was empirically found while scale testing.
851    logging_interval = max(logging_interval, timedelta(seconds=10 + world_size / 1000))
852
853    start = time.time()
854    while True:
855        try:
856            # This will throw an exception after the logging_interval in which we print out
857            # the status of the group or time out officially, throwing runtime error
858            store.wait([last_worker_key], logging_interval)
859            break
860        except RuntimeError as e:
861            worker_count = store.add(store_key, 0)
862            # Print status periodically to keep track.
863            logger.debug(
864                "Waiting in store based barrier to initialize process group for "
865                "rank: %s, key: %s (world_size=%s, num_workers_joined=%s, timeout=%s error=%s)",
866                rank,
867                store_key,
868                world_size,
869                worker_count,
870                timeout,
871                e,
872            )
873
874            if timedelta(seconds=(time.time() - start)) > timeout:
875                raise DistStoreError(  # noqa: B904
876                    "Timed out initializing process group in store based barrier on "
877                    f"rank {rank}, for key: {store_key} (world_size={world_size}, "
878                    f"num_workers_joined={worker_count}, timeout={timeout} error={e})"
879                )
880
881    logger.info(
882        "Rank %s: Completed store-based barrier for key:%s with %s nodes.",
883        rank,
884        store_key,
885        world_size,
886    )
887
888
889def _rank_not_in_group(group: Optional[ProcessGroup]) -> bool:
890    """Check if the current process's rank is not in a given group."""
891    if group is None:
892        return False
893    return group == GroupMember.NON_GROUP_MEMBER
894
895
896def _warn_not_in_group(op_name) -> None:
897    global_rank = -1 if GroupMember.WORLD is None else GroupMember.WORLD.rank()
898    warnings.warn(
899        f"Running {op_name} on global rank {global_rank} which does not "
900        "belong to the given group."
901    )
902
903
904def get_group_rank(group: ProcessGroup, global_rank: int) -> int:
905    """
906    Translate a global rank into a group rank.
907
908    ``global_rank`` must be part of ``group`` otherwise this raises RuntimeError.
909
910    Args:
911        group (ProcessGroup): ProcessGroup to find the relative rank.
912        global_rank (int): Global rank to query.
913
914    Returns:
915        Group rank of ``global_rank`` relative to ``group``
916
917    N.B. calling this function on the default process group returns identity
918    """
919    if group is GroupMember.WORLD:
920        return global_rank
921    if group not in _world.pg_group_ranks:
922        raise ValueError(
923            f"Group {group} is not registered, please create group with torch.distributed.new_group API"
924        )
925    group_ranks = _world.pg_group_ranks[group]
926    if global_rank not in group_ranks:
927        raise ValueError(f"Global rank {global_rank} is not part of group {group}")
928
929    return group_ranks[global_rank]
930
931
932def get_global_rank(group: ProcessGroup, group_rank: int) -> int:
933    """
934    Translate a group rank into a global rank.
935
936    ``group_rank`` must be part of `group` otherwise this raises RuntimeError.
937
938    Args:
939        group (ProcessGroup): ProcessGroup to find the global rank from.
940        group_rank (int): Group rank to query.
941
942    Returns:
943        Global rank of ``group_rank`` relative to ``group``
944
945    N.B. calling this function on the default process group returns identity
946    """
947    if group is GroupMember.WORLD:
948        return group_rank
949    if group not in _world.pg_group_ranks:
950        raise ValueError(
951            f"Group {group} is not registered, please create group with torch.distributed.new_group API"
952        )
953    for rank, grp_rank in _world.pg_group_ranks[group].items():
954        if grp_rank == group_rank:
955            return rank
956    raise ValueError(f"Group rank {group_rank} is not part of group {group}")
957
958
959# TODO: remove this once the ecosystem moves away from it.
960@deprecated(
961    "`torch.distributed.distributed_c10d._get_global_rank` is deprecated, "
962    "please use `torch.distributed.distributed_c10d.get_global_rank` instead",
963    category=FutureWarning,
964)
965def _get_global_rank(group, rank) -> int:
966    """Use get_global_rank as this method is deprecated."""
967    return get_global_rank(group, rank)
968
969
970def get_process_group_ranks(group: ProcessGroup) -> List[int]:
971    """
972    Get all ranks associated with ``group``.
973
974    Args:
975        group (ProcessGroup): ProcessGroup to get all ranks from.
976
977    Returns:
978        List of global ranks ordered by group rank.
979    """
980    return list(_world.pg_group_ranks[group].keys())
981
982
983def _get_group_size(group) -> int:
984    """Get a given group's world size."""
985    if group is GroupMember.WORLD or group is None:
986        default_pg = _get_default_group()
987        return default_pg.size()
988    return group.size()
989
990
991def _get_group_size_by_name(group_name: str) -> int:
992    group = _resolve_process_group(group_name)
993    return group.size()
994
995
996def _resolve_group_name_by_ranks_and_tag(ranks: List[int], tag: str) -> str:
997    # TODO(yifu): remove this function once ranks + tag is not a supported
998    # identifier for process group for functional collectives.
999    group = _find_pg_by_ranks_and_tag(tag, ranks)
1000    if group is None:
1001        raise ValueError("")
1002    return group.group_name
1003
1004
1005def _check_single_tensor(param, param_name) -> None:
1006    """Check that the parameter ``param_name`` is a single tensor."""
1007    if not isinstance(param, torch.Tensor):
1008        raise TypeError(
1009            f"""Invalid function argument. Expected parameter `{param_name}` of type torch.Tensor
1010             but got {type(param)} instead."""
1011        )
1012
1013
1014def _check_tensor_list(param, param_name) -> None:
1015    """Check that the parameter ``param_name`` is a list of tensors."""
1016    if not isinstance(param, list):
1017        raise TypeError(
1018            f"""Invalid function argument. Expected parameter `{param_name}` of type List[torch.Tensor]
1019             but got {type(param)} instead."""
1020        )
1021    elif not all(isinstance(p, torch.Tensor) for p in param):
1022        raise TypeError(
1023            f"""Invalid function argument. Expected parameter `{param_name}` of type List[torch.Tensor]
1024             but got {type(param)} with elements of type {[type(p) for p in param]}."""
1025        )
1026
1027
1028def _as_iterable(obj) -> collections.abc.Iterable:
1029    return obj if isinstance(obj, list) else (obj,)
1030
1031
1032def _ensure_all_tensors_same_dtype(*tensors) -> None:
1033    last_dtype = None
1034    for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)):
1035        tensor_dtype = tensor.dtype
1036        # Mixing complex and its element type is allowed
1037        if tensor_dtype.is_complex:
1038            tensor_dtype = (
1039                torch.float32 if tensor_dtype == torch.complex64 else torch.complex128
1040            )
1041
1042        if last_dtype is None:
1043            last_dtype = tensor_dtype
1044        else:
1045            if last_dtype != tensor_dtype:
1046                raise ValueError(
1047                    "Invalid usage of tensors with different dtypes"
1048                    f"Found {last_dtype} and  {tensor.dtype}"
1049                )
1050
1051
1052def _check_op(op) -> None:
1053    """Check that the ``op`` is either isend or irecv."""
1054    if op not in [isend, irecv]:
1055        raise ValueError(
1056            "Invalid ``op``. Expected ``op`` "
1057            "to be of type ``torch.distributed.isend`` or "
1058            "``torch.distributed.irecv``."
1059        )
1060
1061
1062def _check_p2p_op_list(p2p_op_list) -> None:
1063    """
1064    Check that the ``p2p_op_list`` is a list of P2POp instances.
1065
1066    Also, check that all ops use the same group.
1067    """
1068    if not isinstance(p2p_op_list, list) or not all(
1069        isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list
1070    ):
1071        raise ValueError(
1072            "Invalid ``p2p_op_list``. Each op is expected to "
1073            "to be of type ``torch.distributed.P2POp``."
1074        )
1075
1076    group = p2p_op_list[0].group
1077    if not all(group == p2p_op.group for p2p_op in p2p_op_list):
1078        raise ValueError("All ops need to use the same group.")
1079
1080
1081def is_mpi_available() -> bool:
1082    """Check if the MPI backend is available."""
1083    return _MPI_AVAILABLE
1084
1085
1086def is_nccl_available() -> bool:
1087    """Check if the NCCL backend is available."""
1088    return _NCCL_AVAILABLE
1089
1090
1091def is_gloo_available() -> bool:
1092    """Check if the Gloo backend is available."""
1093    return _GLOO_AVAILABLE
1094
1095
1096def is_ucc_available() -> bool:
1097    """Check if the UCC backend is available."""
1098    return _UCC_AVAILABLE
1099
1100
1101def is_backend_available(backend: str) -> bool:
1102    """
1103    Check backend availability.
1104
1105    Checks if the given backend is available and supports the built-in backends or
1106    third-party backends through function ``Backend.register_backend``.
1107
1108    Args:
1109        backend (str): Backend name.
1110    Returns:
1111        bool: Returns true if the backend is available otherwise false.
1112    """
1113    # If the backend has an ``is_backend_available`` function, return the result of that function directly
1114    available_func = getattr(torch.distributed, f"is_{backend.lower()}_available", None)
1115    if available_func:
1116        return available_func()
1117
1118    return backend.lower() in Backend.backend_list
1119
1120
1121def is_initialized() -> bool:
1122    """Check if the default process group has been initialized."""
1123    return GroupMember.WORLD is not None
1124
1125
1126def is_torchelastic_launched() -> bool:
1127    """
1128    Check whether this process was launched with ``torch.distributed.elastic`` (aka torchelastic).
1129
1130    The existence of ``TORCHELASTIC_RUN_ID`` environment
1131    variable is used as a proxy to determine whether the current process
1132    was launched with torchelastic. This is a reasonable proxy since
1133    ``TORCHELASTIC_RUN_ID`` maps to the rendezvous id which is always a
1134    non-null value indicating the job id for peer discovery purposes..
1135    """
1136    return os.getenv("TORCHELASTIC_RUN_ID") is not None
1137
1138
1139def _is_barrier_after_init() -> int:
1140    # Environment variable to control whether process group should perform a
1141    # barrier after its init. Default value is 0, i.e. no barrier. If you
1142    # experience issue with this setting, you may set
1143    # `TORCH_DIST_INIT_BARRIER=1` to add the barrier.
1144    return int(os.getenv("TORCH_DIST_INIT_BARRIER", "0"))
1145
1146
1147def _get_default_group() -> ProcessGroup:
1148    """Get the default process group created by init_process_group."""
1149    if not is_initialized():
1150        raise ValueError(
1151            "Default process group has not been initialized, "
1152            "please make sure to call init_process_group."
1153        )
1154    if TYPE_CHECKING:
1155        return not_none(GroupMember.WORLD)
1156    else:
1157        return GroupMember.WORLD
1158
1159
1160def _get_default_store() -> Store:
1161    """Get the default store created by init_process_group."""
1162    if not is_initialized():
1163        raise ValueError(
1164            "Default process group has not been initialized, "
1165            "please make sure to call init_process_group."
1166        )
1167    default_pg = _get_default_group()
1168    _, default_store = _world.pg_map[default_pg]
1169    return default_store
1170
1171
1172def _update_default_pg(pg) -> None:
1173    _world.default_pg = pg
1174    rank = pg.rank() if pg is not None and pg != GroupMember.NON_GROUP_MEMBER else -1
1175    torch._C._distributed_c10d._set_global_rank(rank)
1176
1177
1178def get_backend_config(group: Optional[ProcessGroup] = None) -> str:
1179    """
1180    Return the backend configuration of the given process group.
1181
1182    Args:
1183        group (ProcessGroup, optional): The process group to work on. The
1184            default is the general main process group. If another specific group
1185            is specified, the calling process must be part of :attr:`group`.
1186
1187    Returns:
1188        The backend configuration of the given process group as a lower case string.
1189
1190    """
1191    pg = group or _get_default_group()
1192    if _rank_not_in_group(pg):
1193        raise ValueError("Invalid process group specified")
1194    backend_config = _world.pg_backend_config.get(pg)
1195    return str(not_none(backend_config))
1196
1197
1198def get_backend(group: Optional[ProcessGroup] = None) -> Backend:
1199    """
1200    Return the backend of the given process group.
1201
1202    Args:
1203        group (ProcessGroup, optional): The process group to work on. The
1204            default is the general main process group. If another specific group
1205            is specified, the calling process must be part of :attr:`group`.
1206
1207    Returns:
1208        The backend of the given process group as a lower case string.
1209
1210    """
1211    pg = group or _get_default_group()
1212    if _rank_not_in_group(pg):
1213        raise ValueError("Invalid process group specified")
1214    pg_store = _world.pg_map[pg] if pg in _world.pg_map else None
1215    return Backend(not_none(pg_store)[0])
1216
1217
1218def _get_process_group_uid(pg: ProcessGroup) -> int:
1219    backend = None
1220    try:
1221        backend = pg._get_backend(torch.device("cuda"))
1222    except RuntimeError:
1223        pass
1224    if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
1225        return backend.uid
1226    return -1
1227
1228
1229def _get_pg_config(group: Optional[ProcessGroup] = None) -> Dict[str, Any]:
1230    """
1231    Return the pg configuration of the given process group.
1232
1233    """
1234    pg = group or _get_default_group()
1235    return {
1236        "pg_name": _get_process_group_name(pg),
1237        "pg_desc": pg.group_desc,
1238        "backend_config": get_backend_config(pg),
1239        "pg_size": _get_group_size(pg),
1240        "ranks": get_process_group_ranks(pg),
1241    }
1242
1243
1244def _get_all_pg_configs() -> List[Dict[str, Any]]:
1245    """
1246    Return the pg configuration of all the process groups.
1247
1248    """
1249    config_info: List[Dict[str, Any]] = []
1250    for pg in _world.pg_map.keys():
1251        config_info.append(_get_pg_config(pg))
1252    return config_info
1253
1254
1255def get_pg_count() -> int:
1256    """
1257    Return the number of process groups.
1258
1259    """
1260    return _world.group_count
1261
1262
1263def get_node_local_rank(fallback_rank: Optional[int] = None) -> int:
1264    """
1265    Return the local rank of the current process relative to the node.
1266
1267    Semantically, this is a useful concept for mapping processes to devices.
1268    For example, on a node with 8 accelerator you could use the node local rank to decide
1269    which accelerator device to bind the process to.
1270
1271    In practice, the actual assignment of node local ranks is handled by the process launcher outside of pytorch,
1272    and communicated via the `LOCAL_RANK` environment variable.
1273
1274    Torchrun will automatically populate `LOCAL_RANK`, but other launchers may not.  If `LOCAL_RANK` is unspecified,
1275    this API will fall back to the provided kwarg 'fallback_rank' if specified, otherwise it will raise an error. The
1276    intent is to allow writing an application that runs either in single or multi device contexts without error.
1277
1278    """
1279    if "LOCAL_RANK" in os.environ:
1280        return int(os.environ["LOCAL_RANK"])
1281    elif fallback_rank is not None:
1282        return int(fallback_rank)
1283    raise RuntimeError(
1284        "LOCAL_RANK is not in the environment. Consider passing fallback_rank to allow `get_node_local_rank` to work, "
1285        "assuming you are not running in a multi-device context and want the code to run locally instead."
1286    )
1287
1288
1289def _add_ephemeral_timeout_for_all_pgs(timeout: timedelta) -> None:
1290    """
1291    This API adds an ephemeral timeout extension for all PGs locally
1292    on one rank. The timeout gets reset when the first collective issued
1293    after API called finished.
1294    NOTE: We only support to set timeout for cuda backends for now.
1295    NOTE: While this feature
1296    provides flexibility in specific scenarios, it introduces statefulness
1297    to timeout setting. Therefore, it is advisable to use this API sparingly
1298    and consider alternative approaches, such as directly setting the timeout
1299    or utilizing a barrier collective (one can set any timeout to the barrier),
1300    whenever feasible.
1301
1302    Args:
1303        timeout (timedelta): The delta of timeout to extend.
1304
1305    Returns:
1306        None.
1307    """
1308    for pg in _world.pg_map.keys():
1309        devices = pg._device_types
1310        if torch.device("cuda") in devices:
1311            backend = pg._get_backend(torch.device("cuda"))
1312            if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
1313                backend._add_ephemeral_timeout(timeout)
1314
1315
1316def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) -> None:
1317    """
1318    Set the timeout for the given process group when users want to use a different timeout instead of
1319    default values.
1320
1321    Args:
1322        timeout (timedelta): Timeout for operations executed against the process group which
1323            users want to set. Default value is 10 minutes for NCCL and 30 minutes for other backends.
1324            This is the duration after which collectives will be aborted asynchronously and the process will crash.
1325            This is done since CUDA execution is async and it is no longer safe to continue executing user code since
1326            failed async NCCL operations might result in subsequent CUDA operations running on corrupted data.
1327            When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout.
1328
1329        group (ProcessGroup, optional): The process group to work on. The
1330            default is the general main process group. If another specific group
1331            is specified, the calling process must be part of :attr:`group`.
1332
1333    Returns:
1334        None
1335    """
1336    if group is None:
1337        group = _get_default_group()
1338    if _rank_not_in_group(group):
1339        raise ValueError("Invalid process group specified")
1340    assert isinstance(group, ProcessGroup)
1341    devices = group._device_types
1342    backends = set()
1343    if torch.device("cpu") in devices and is_gloo_available():
1344        backend = group._get_backend(torch.device("cpu"))
1345        if isinstance(backend, ProcessGroupGloo):
1346            backends.add(backend)
1347    if torch.device("cuda") in devices:
1348        backend = group._get_backend(torch.device("cuda"))
1349        if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
1350            backends.add(backend)  # type: ignore[arg-type]
1351        elif is_gloo_available() and isinstance(backend, ProcessGroupGloo):
1352            backends.add(backend)  # type: ignore[arg-type]
1353    if len(backends) == 0:
1354        warnings.warn("Set timeout is now only supported for either nccl or gloo.")
1355    for backend in backends:
1356        backend._set_default_timeout(timeout)
1357
1358
1359@_exception_logger
1360@_time_logger
1361def init_process_group(
1362    backend: Optional[str] = None,
1363    init_method: Optional[str] = None,
1364    timeout: Optional[timedelta] = None,
1365    world_size: int = -1,
1366    rank: int = -1,
1367    store: Optional[Store] = None,
1368    group_name: str = "",
1369    pg_options: Optional[Any] = None,
1370    device_id: Optional[torch.device] = None,
1371) -> None:
1372    """
1373    Initialize the default distributed process group.
1374
1375    This will also initialize the distributed package.
1376
1377    There are 2 main ways to initialize a process group:
1378        1. Specify ``store``, ``rank``, and ``world_size`` explicitly.
1379        2. Specify ``init_method`` (a URL string) which indicates where/how
1380           to discover peers. Optionally specify ``rank`` and ``world_size``,
1381           or encode all required parameters in the URL and omit them.
1382
1383    If neither is specified, ``init_method`` is assumed to be "env://".
1384
1385
1386    Args:
1387        backend (str or Backend, optional): The backend to use. Depending on
1388            build-time configurations, valid values include ``mpi``, ``gloo``,
1389            ``nccl``, and ``ucc``. If the backend is not provided, then both a ``gloo``
1390            and ``nccl`` backend will be created, see notes below for how multiple
1391            backends are managed. This field can be given as a lowercase string
1392            (e.g., ``"gloo"``), which can also be accessed via
1393            :class:`Backend` attributes (e.g., ``Backend.GLOO``). If using
1394            multiple processes per machine with ``nccl`` backend, each process
1395            must have exclusive access to every GPU it uses, as sharing GPUs
1396            between processes can result in deadlocks. ``ucc`` backend is
1397            experimental.
1398        init_method (str, optional): URL specifying how to initialize the
1399                                     process group. Default is "env://" if no
1400                                     ``init_method`` or ``store`` is specified.
1401                                     Mutually exclusive with ``store``.
1402        world_size (int, optional): Number of processes participating in
1403                                    the job. Required if ``store`` is specified.
1404        rank (int, optional): Rank of the current process (it should be a
1405                              number between 0 and ``world_size``-1).
1406                              Required if ``store`` is specified.
1407        store(Store, optional): Key/value store accessible to all workers, used
1408                                to exchange connection/address information.
1409                                Mutually exclusive with ``init_method``.
1410        timeout (timedelta, optional): Timeout for operations executed against
1411            the process group. Default value is 10 minutes for NCCL and 30 minutes for other backends.
1412            This is the duration after which collectives will be aborted asynchronously and the process will crash.
1413            This is done since CUDA execution is async and it is no longer safe to continue executing user code since
1414            failed async NCCL operations might result in subsequent CUDA operations running on corrupted data.
1415            When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout.
1416
1417        group_name (str, optional, deprecated): Group name. This argument is ignored
1418        pg_options (ProcessGroupOptions, optional): process group options
1419            specifying what additional options need to be passed in during
1420            the construction of specific process groups. As of now, the only
1421            options we support is ``ProcessGroupNCCL.Options`` for the ``nccl``
1422            backend, ``is_high_priority_stream`` can be specified so that
1423            the nccl backend can pick up high priority cuda streams when
1424            there're compute kernels waiting. For other availble options to config nccl,
1425            See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
1426        device_id (torch.device, optional): a single, specific device
1427            to "bind" this process to, allowing for backend-specific
1428            optimizations.  Currently this has two effects, only under
1429            NCCL: the communicator is immediately formed (calling
1430            ``ncclCommInit*`` immediately rather than the normal lazy
1431            call) and sub-groups will use ``ncclCommSplit`` when
1432            possible to avoid unnecessary overhead of group creation. If you
1433            want to know NCCL initialization error early, you can also use this
1434            field.
1435
1436    .. note:: To enable ``backend == Backend.MPI``, PyTorch needs to be built from source
1437        on a system that supports MPI.
1438
1439    .. note:: Support for multiple backends is experimental. Currently when no backend is
1440        specified, both ``gloo`` and ``nccl`` backends will be created. The ``gloo`` backend
1441        will be used for collectives with CPU tensors and the ``nccl`` backend will be used
1442        for collectives with CUDA tensors. A custom backend can be specified by passing in
1443        a string with format "<device_type>:<backend_name>,<device_type>:<backend_name>", e.g.
1444        "cpu:gloo,cuda:custom_backend".
1445
1446    """
1447
1448    global _world
1449
1450    global _backend
1451    global _default_pg_init_method
1452
1453    if GroupMember.WORLD is not None:
1454        raise ValueError("trying to initialize the default process group twice!")
1455
1456    set_pytorch_distributed_envs_from_justknobs()
1457
1458    # Depending on the import order, some trace_rules functions may be evaluated
1459    # during the import phase. In such a case, these functions may not correctly
1460    # add the distributed related rules due to import circular dependency.
1461    # We need to clear the lru_cache during the runtime to ensure the correctness
1462    # of these trace_rules.
1463    #
1464    # Since this API must be called before all distributed code being compiled,
1465    # clearing the cache here should be safe.
1466    if "torch._dynamo" in sys.modules:
1467        torch._dynamo.trace_rules.clear_lru_cache()
1468
1469    assert (store is None) or (
1470        init_method is None
1471    ), "Cannot specify both init_method and store."
1472
1473    if store is not None:
1474        assert world_size > 0, "world_size must be positive if using store"
1475        assert rank >= 0, "rank must be non-negative if using store"
1476    elif init_method is None:
1477        init_method = "env://"
1478
1479    if backend:
1480        backend = Backend(backend)
1481    else:
1482        backend = Backend("undefined")
1483
1484    if timeout is None:
1485        timeout = _get_default_timeout(backend)
1486
1487    _check_valid_timeout(timeout)
1488
1489    """
1490    Group name is not visible to users unless they access
1491    internals of c10d. This means we can ignore the value
1492    they provide as it not exposed in a public way.
1493    """
1494    group_name = _process_group_name([], use_hashed_name=False)
1495    if backend == Backend.MPI:
1496        if world_size != -1 or rank != -1:
1497            warnings.warn(
1498                f"For MPI backend, world_size ({world_size}) and rank ({rank}) "
1499                "are ignored since they are assigned by the "
1500                "MPI runtime."
1501            )
1502
1503        default_pg, _ = _new_process_group_helper(
1504            -1,
1505            -1,
1506            [],
1507            backend,
1508            None,
1509            group_name,
1510            timeout=timeout,
1511            group_desc="default_pg",
1512        )
1513        _update_default_pg(default_pg)
1514    else:
1515        # backward compatible API
1516        if store is None:
1517            rendezvous_iterator = rendezvous(
1518                not_none(init_method), rank, world_size, timeout=timeout
1519            )
1520            store, rank, world_size = next(rendezvous_iterator)
1521            store.set_timeout(timeout)
1522
1523            # Use a PrefixStore to avoid accidental overrides of keys used by
1524            # different systems (e.g. RPC) in case the store is multi-tenant.
1525            store = PrefixStore("default_pg", store)
1526
1527        default_pg, _ = _new_process_group_helper(
1528            world_size,
1529            rank,
1530            [],
1531            backend,
1532            store,
1533            group_name,
1534            pg_options=pg_options,
1535            timeout=timeout,
1536            device_id=device_id,
1537            group_desc="default_pg",
1538        )
1539        _update_default_pg(default_pg)
1540
1541    _world.pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())}  # type: ignore[attr-defined, index]
1542    _backend = _world.pg_map[not_none(GroupMember.WORLD)][0]
1543    _default_pg_init_method = init_method
1544
1545    old_hook = sys.excepthook
1546    excepthook_prefix = f"[rank{get_rank()}]"
1547
1548    def _distributed_excepthook(*args):
1549        old_stderr = sys.stderr
1550        sys.stderr = buf = io.StringIO()
1551        try:
1552            old_hook(*args)
1553        finally:
1554            sys.stderr = old_stderr
1555        msg = buf.getvalue()
1556        msg = "\n".join(
1557            f"{excepthook_prefix}: {s}" if s != "" else "" for s in msg.split("\n")
1558        )
1559        sys.stderr.write(msg)
1560        sys.stderr.flush()
1561
1562    sys.excepthook = _distributed_excepthook
1563
1564    if _is_barrier_after_init() == 1:
1565        # barrier at the end to ensure that once we return from this method, all
1566        # process groups including global variables (if any) are updated
1567        # correctly on all ranks.
1568        # Update 04/2023: for large-scale runs, this barrier (esp. store-based
1569        # barrier) may be costly and/or unscalable. Also, in a lot of cases,
1570        # these barriers may be unnecessary, as proven by a green CI after
1571        # removal. An environment variable `TORCH_DIST_INIT_BARRIER` has been
1572        # added which enables this barrier only when set to 1.
1573        logger.debug(
1574            "Performing barrier after ProcessGroup initialization since "
1575            "TORCH_DIST_INIT_BARRIER = 1"
1576        )
1577        if backend == Backend.MPI:
1578            # MPI backend doesn't use store.
1579            barrier()
1580        else:
1581            # Use store based barrier here since barrier() used a bunch of
1582            # default devices and messes up NCCL internal state.
1583            _store_based_barrier(rank, store, group_name, world_size, timeout)
1584
1585
1586def _get_split_source(pg):
1587    split_from = None
1588    if pg.bound_device_id:
1589        split_from = pg._get_backend(pg.bound_device_id)
1590    elif pg is _world.default_pg:
1591        try:
1592            split_from = pg._get_backend(torch.device("cuda"))
1593        except RuntimeError:
1594            # no cuda device associated with this backend
1595            pass
1596
1597    if not split_from or not split_from.supports_splitting:
1598        return None
1599
1600    # If necessary, find a backend to split from by peeling process
1601    # group wrappers from our potentially wrapped process group.
1602    while _GLOO_AVAILABLE and isinstance(split_from, _ProcessGroupWrapper):
1603        split_from = split_from.wrapped_pg
1604
1605    return split_from
1606
1607
1608def _shutdown_backend(pg):
1609    """
1610    Try to shut down the backend of a process group.
1611    Currently, only ProcessGroupNCCL backend is supported.
1612    No op for other backends.
1613    """
1614    backend = None
1615    try:
1616        backend = pg._get_backend(torch.device("cuda"))
1617    except RuntimeError:
1618        pass
1619    if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
1620        # explictly call shutdown to ensure that NCCL resources are released
1621        backend._shutdown()
1622
1623
1624def _new_process_group_helper(
1625    group_size,
1626    group_rank,
1627    global_ranks_in_group,
1628    backend,
1629    store,
1630    group_name,
1631    pg_options=None,
1632    timeout=None,
1633    pg_tag=None,
1634    device_id=None,
1635    group_desc=None,
1636):
1637    """
1638    Create a new distributed process group.
1639
1640    This function must be called by ALL processes in the global group, even if
1641    the calling process is not part of the newly created group. In that case,
1642    this function returns GroupMember.NON_GROUP_MEMBER.
1643
1644    This function is called with ``global_ranks_in_group == []`` for the default group.
1645    """
1646    global _world
1647
1648    if group_name in _world.pg_names.values():
1649        raise ValueError(
1650            "The specified group name has already been "
1651            "created, please use a different group name"
1652        )
1653
1654    if device_id is not None and (device_id.index is None or device_id.type != "cuda"):
1655        raise ValueError(
1656            "init_process_group device_id parameter must be a cuda device with an "
1657            "id, e.g. cuda:0, not just cuda or cpu"
1658        )
1659
1660    # Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value
1661    _check_valid_timeout(timeout)
1662
1663    if pg_tag not in [None, ""]:
1664        # creating with the same tag and rank set results in the same underlying PG
1665        existing_group = _find_pg_by_ranks_and_tag(pg_tag, global_ranks_in_group)
1666        if existing_group:
1667            _, prefix_store = _world.pg_map[existing_group]
1668            return existing_group, prefix_store
1669
1670    group_desc = "undefined" if group_desc is None else group_desc
1671
1672    # The list of group ranks is empty if we're creating the default group.
1673    is_default_group = len(global_ranks_in_group) == 0
1674
1675    # nccl and potentially other backends allow creation of
1676    # communicators based on pre-existing ones, which can save
1677    # initialization time.  Due to lazy initialization of
1678    # communicators in some backends, we have to be careful and only
1679    # split when we *know* the backends already are connected _on all
1680    # ranks_.  We can only know this if the group we are making is the
1681    # entire world or if we have bound a device id to the world (which
1682    # causes early connection initialization).
1683    if is_initialized() and (
1684        len(global_ranks_in_group) == _get_default_group().size()
1685        or _get_default_group().bound_device_id
1686    ):
1687        split_from = _get_split_source(_get_default_group())
1688    else:
1689        split_from = None
1690
1691    # If this is a subgroup (which means group_ranks is specified),
1692    # we check if the current process is a member of the new group.
1693    if not is_default_group:
1694        global_rank = _get_default_group().rank()
1695        if global_rank not in global_ranks_in_group:
1696            # If we are using `ncclCommSplit` (or similar split from
1697            # other APIs) to create the communicator, we will need to
1698            # call `ncclCommSplit` on *all* ranks in this new group's
1699            # parent group, even those not in the new group.  This is
1700            # a requirement of the NCCL API as otherwise we would get
1701            # out of sync.
1702            if split_from:
1703                split_from.perform_nocolor_split(_get_default_group().bound_device_id)
1704            return GroupMember.NON_GROUP_MEMBER, None
1705
1706    prefix_store = PrefixStore(f"{group_name}/", store)
1707    base_pg_options = ProcessGroup.Options(backend=str(backend))
1708    base_pg_options._timeout = timeout
1709    pg: ProcessGroup = ProcessGroup(
1710        prefix_store, group_rank, group_size, base_pg_options
1711    )
1712    if device_id:
1713        pg.bound_device_id = device_id
1714    backend_config = BackendConfig(backend)
1715    backend_class: torch._C._distributed_c10d.Backend
1716    for device, backend_str in backend_config.get_device_backend_map().items():
1717        # Use the group name as prefix in the default store, such that
1718        # a single store can be reused by multiple groups.
1719        backend_prefix_store = PrefixStore(f"{device}/", prefix_store)
1720
1721        if backend_str == Backend.MPI:
1722            if not is_mpi_available():
1723                raise RuntimeError(
1724                    "Distributed package doesn't have MPI built in."
1725                    " MPI is only included if you build PyTorch from"
1726                    " source on a host that has MPI installed."
1727                )
1728            backend_class = ProcessGroupMPI.create(global_ranks_in_group)
1729            backend_type = ProcessGroup.BackendType.MPI
1730            if not backend_class:
1731                return GroupMember.NON_GROUP_MEMBER, None
1732            # create new process group with accurate rank and size
1733            if pg.rank() == -1 and pg.size() == -1:
1734                pg = ProcessGroup(
1735                    backend_prefix_store,
1736                    backend_class.rank(),
1737                    backend_class.size(),
1738                    base_pg_options,
1739                )
1740        elif backend_str == Backend.GLOO:
1741            # TODO: remove this check after lazy initialization is supported
1742            # if pg_options is not None:
1743            #     raise RuntimeError("GLOO options not supported")
1744            backend_class = ProcessGroupGloo(
1745                backend_prefix_store, group_rank, group_size, timeout=timeout
1746            )
1747            backend_type = ProcessGroup.BackendType.GLOO
1748        elif backend_str == Backend.NCCL:
1749            if not is_nccl_available():
1750                raise RuntimeError("Distributed package doesn't have NCCL built in")
1751            if pg_options is not None:
1752                assert isinstance(
1753                    pg_options, ProcessGroupNCCL.Options
1754                ), "Expected pg_options argument to be of type ProcessGroupNCCL.Options"
1755                if pg_options._timeout != timeout:
1756                    warnings.warn(
1757                        "pg_options._timeout was specified, "
1758                        "but timeout kwarg has a default value that will always override it. "
1759                    )
1760            else:
1761                # default pg_options for NCCL
1762                pg_options = ProcessGroupNCCL.Options()
1763                pg_options.is_high_priority_stream = False
1764            pg_options._timeout = timeout
1765
1766            if split_from:
1767                pg_options.split_from = split_from
1768                pg_options.split_color = _process_group_color(global_ranks_in_group)
1769            pg_options.global_ranks_in_group = global_ranks_in_group
1770            pg_options.group_name = group_name
1771            backend_class = ProcessGroupNCCL(
1772                backend_prefix_store, group_rank, group_size, pg_options
1773            )
1774            backend_type = ProcessGroup.BackendType.NCCL
1775        elif backend_str == Backend.UCC and is_ucc_available():
1776            # TODO: once UCC plugin is fully deprecated, remove
1777            # is_ucc_available() from above elif-condition and raise
1778            # RuntimeError if is_ucc_available() returns false.
1779
1780            backend_class = ProcessGroupUCC(
1781                backend_prefix_store, group_rank, group_size, timeout=timeout
1782            )
1783            backend_type = ProcessGroup.BackendType.UCC
1784        else:
1785            assert (
1786                backend_str.upper() in Backend._plugins
1787            ), f"Unknown c10d backend type {backend_str.upper()}"
1788
1789            backend_plugin = Backend._plugins[backend_str.upper()]
1790            creator_fn = backend_plugin.creator_fn
1791            extended_api = backend_plugin.extended_api
1792            backend_type = ProcessGroup.BackendType.CUSTOM
1793
1794            if not extended_api:
1795                backend_class = creator_fn(
1796                    backend_prefix_store, group_rank, group_size, timeout
1797                )
1798            else:
1799                dist_backend_opts = _DistributedBackendOptions()
1800                dist_backend_opts.store = backend_prefix_store
1801                dist_backend_opts.group_rank = group_rank
1802                dist_backend_opts.group_size = group_size
1803                dist_backend_opts.timeout = timeout
1804                dist_backend_opts.group_id = group_name
1805                dist_backend_opts.global_ranks_in_group = global_ranks_in_group
1806
1807                backend_class = creator_fn(dist_backend_opts, pg_options)
1808
1809        # Set sequence numbers for gloo and nccl backends.
1810        if backend_str == Backend.GLOO:
1811            assert isinstance(backend_class, ProcessGroupGloo)
1812            backend_class._set_sequence_number_for_group()
1813        elif backend_str == Backend.NCCL:
1814            assert isinstance(backend_class, ProcessGroupNCCL)
1815            backend_class._set_sequence_number_for_group()
1816
1817        # If the type is a subclass of ProcessGroup then return this process group immediately
1818        # TODO: This defaults to the old behavior for PythonProcessGroups which overwrites the
1819        # ProcessGroup instance
1820        if issubclass(type(backend_class), ProcessGroup):
1821            pg = backend_class  # type: ignore[assignment]
1822            break
1823
1824        # Process group wrapper initialization for supported PGs when TORCH_DISTRIBUTED_DEBUG is set
1825        if (
1826            backend_str in [Backend.GLOO, Backend.NCCL, Backend.UCC]
1827            or backend_str.upper() in Backend._plugins
1828        ):
1829            # In debug mode and if GLOO is available, wrap in a wrapper PG that
1830            # enables enhanced collective checking for debuggability.
1831            if get_debug_level() == DebugLevel.DETAIL:
1832                if not _GLOO_AVAILABLE:
1833                    logger.info(
1834                        """TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but
1835                                GLOO is not available. Build with Gloo to
1836                                create a wrapper process group in debug mode
1837                                to aid collective desynchronization debugging."""
1838                    )
1839                else:
1840                    backend_class = _create_process_group_wrapper(
1841                        wrapped_pg=backend_class,
1842                        store_prefix=group_name,
1843                        store=backend_prefix_store,
1844                        rank=group_rank,
1845                        world_size=group_size,
1846                        timeout=timeout,
1847                    )
1848
1849        # register only a single backend when all get_device_backend_map values are the same
1850        if len(set(backend_config.get_device_backend_map().values())) == 1:
1851            for device in backend_config.get_device_backend_map().keys():
1852                pg._register_backend(torch.device(device), backend_type, backend_class)
1853
1854            # break out of outer loop to not create any more backends
1855            break
1856
1857        pg._register_backend(torch.device(device), backend_type, backend_class)
1858
1859    # set group_name and group_dsec to backend
1860    assert group_name is not None
1861    assert group_desc is not None
1862    pg._set_group_name(group_name)
1863    pg._set_group_desc(group_desc)
1864
1865    if device_id and pg._get_backend(device_id).supports_splitting:
1866        eager_backend = pg._get_backend(device_id)
1867        eager_backend.eager_connect_single_device(device_id)
1868
1869    # update global state
1870    _world.pg_map[pg] = (backend, prefix_store)
1871    _world.pg_names[pg] = group_name
1872    _register_process_group(group_name, pg)
1873
1874    _world.pg_backend_config[pg] = str(backend_config)
1875    # "" is the default tag for user PGs
1876    if pg_tag in [None, ""]:
1877        pg_tag = f"ptd:{group_name}"
1878        _world.tags_to_pg.setdefault("", []).append(pg)
1879    else:
1880        pg_tag = f"user:{pg_tag}"
1881
1882    _world.tags_to_pg.setdefault(pg_tag, []).append(pg)
1883    _world.pg_to_tag[pg] = pg_tag
1884    return pg, prefix_store
1885
1886
1887def destroy_process_group(group: Optional[ProcessGroup] = None):
1888    """
1889    Destroy a given process group, and deinitialize the distributed package.
1890
1891    Args:
1892        group (ProcessGroup, optional): The process group to be destroyed, if
1893                                        group.WORLD is given, all process
1894                                        groups including the default one will
1895                                        be destroyed.
1896    """
1897    global _world
1898
1899    if group == GroupMember.NON_GROUP_MEMBER:
1900        return
1901
1902    if group is None:
1903        pg = GroupMember.WORLD
1904    else:
1905        pg = group
1906
1907    assert pg is not None
1908    if _world.pg_map.get(pg, None) is None:
1909        raise ValueError("Invalid process group specified")
1910
1911    # When users register Python onCompletion hooks, those hooks will run on a
1912    # different thread than the main thread. Today, the ProcessGroup dtor does
1913    # wait for that thread. However, the dtor might finish after the Python
1914    # Interpreter exits. After that grabbing the GIL for the Python hook will crash.
1915    # We can either revive the interpreter when running hooks or keep the main one
1916    # alive until all works and hooks are done. The current implementation does the
1917    # latter. Therefore, we explicitly call _wait_for_pending_works() here to wait
1918    # for the pending hooks to finish.
1919    if pg.name().lower() == "nccl" and pg._has_hooks():
1920        pg._wait_for_pending_works()
1921
1922    if group is None or group == GroupMember.WORLD:
1923        # shutdown all backends in the order of pg names. shutting down in order because
1924        # ncclCommAbort() was a 'collective' call in some versions of NCCL.
1925        for pg_to_shutdown in sorted(
1926            _world.pg_names, key=lambda x: _world.pg_names[x], reverse=True
1927        ):
1928            _shutdown_backend(pg_to_shutdown)
1929
1930        _update_default_pg(None)
1931        _world.pg_map.clear()
1932        _world.pg_names.clear()
1933        _world.pg_group_ranks.clear()
1934        _world.pg_backend_config.clear()
1935        _world.pg_to_tag.clear()
1936        _world.tags_to_pg.clear()
1937        _world.pg_coalesce_state.clear()
1938        _world.pg_default_device.clear()
1939        _unregister_all_process_groups()
1940
1941        # when process group doesn't have an explicit name (only WORLD (default)
1942        # process group can have an explicit name), we use global _world.group_count
1943        # to generate the name. We need to reset the counter on destruction to
1944        # allow consistent value to be generated when we re-create process
1945        # groups after some trainers recover from failure
1946        #
1947        # We only reset this when WORLD is being destroyed because if this
1948        # process group is in good state, we aren't dealing with failures.
1949        _world.group_count = 0
1950    else:
1951        _shutdown_backend(pg)
1952        del _world.pg_map[pg]
1953        del _world.pg_names[pg]
1954        del _world.pg_group_ranks[pg]
1955        del _world.pg_backend_config[pg]
1956        if pg in _world.pg_default_device:
1957            del _world.pg_default_device[pg]
1958        if pg in _world.pg_coalesce_state.keys():
1959            warnings.warn(
1960                "Some coalesced collectives haven't been launched when "
1961                "ProcessGroup is destroyed. They will be cleaned."
1962            )
1963            del _world.pg_coalesce_state[pg]
1964
1965        tag = _world.pg_to_tag.get(pg)
1966        del _world.pg_to_tag[pg]
1967        if tag is not None:
1968            try:
1969                _world.tags_to_pg[tag].remove(pg)
1970                if tag.startswith("ptd:"):
1971                    _world.tags_to_pg[""].remove(pg)
1972            except Exception:
1973                pass
1974        _unregister_process_group(pg.group_name)
1975
1976
1977def get_rank(group: Optional[ProcessGroup] = None) -> int:
1978    """
1979    Return the rank of the current process in the provided ``group``, default otherwise.
1980
1981    Rank is a unique identifier assigned to each process within a distributed
1982    process group. They are always consecutive integers ranging from 0 to
1983    ``world_size``.
1984
1985    Args:
1986        group (ProcessGroup, optional): The process group to work on. If None,
1987            the default process group will be used.
1988
1989    Returns:
1990        The rank of the process group
1991        -1, if not part of the group
1992
1993    """
1994    if _rank_not_in_group(group):
1995        return -1
1996
1997    default_pg = _get_default_group()
1998    if group is None or group is GroupMember.WORLD:
1999        return default_pg.rank()
2000
2001    return get_group_rank(group, default_pg.rank())
2002
2003
2004def get_world_size(group: Optional[ProcessGroup] = None) -> int:
2005    """
2006    Return the number of processes in the current process group.
2007
2008    Args:
2009        group (ProcessGroup, optional): The process group to work on. If None,
2010            the default process group will be used.
2011
2012    Returns:
2013        The world size of the process group
2014        -1, if not part of the group
2015
2016    """
2017    if _rank_not_in_group(group):
2018        return -1
2019
2020    return _get_group_size(group)
2021
2022
2023def isend(
2024    tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0
2025) -> Optional[Work]:
2026    """
2027    Send a tensor asynchronously.
2028
2029    .. warning::
2030        Modifying ``tensor`` before the request completes causes undefined
2031        behavior.
2032
2033    .. warning::
2034        ``tag`` is not supported with the NCCL backend.
2035
2036    Args:
2037        tensor (Tensor): Tensor to send.
2038        dst (int): Destination rank on global process group (regardless of ``group`` argument)
2039        group (ProcessGroup, optional): The process group to work on. If None,
2040            the default process group will be used.
2041        tag (int, optional): Tag to match send with remote recv
2042
2043    Returns:
2044        A distributed request object.
2045        None, if not part of the group
2046
2047    """
2048    _check_single_tensor(tensor, "tensor")
2049    if _rank_not_in_group(group):
2050        _warn_not_in_group("isend")
2051        return None
2052
2053    if tensor.is_complex():
2054        tensor = torch.view_as_real(tensor)
2055
2056    if group is None or group is GroupMember.WORLD:
2057        pg = _get_default_group()
2058    else:
2059        pg = group
2060        dst = get_group_rank(pg, dst)
2061
2062    return pg.send([tensor], dst, tag)
2063
2064
2065def irecv(
2066    tensor: torch.Tensor,
2067    src: Optional[int] = None,
2068    group: Optional[ProcessGroup] = None,
2069    tag: int = 0,
2070) -> Optional[Work]:
2071    """
2072    Receives a tensor asynchronously.
2073
2074    .. warning::
2075        ``tag`` is not supported with the NCCL backend.
2076
2077    Args:
2078        tensor (Tensor): Tensor to fill with received data.
2079        src (int, optional): Source rank on global process group (regardless of ``group`` argument).
2080            Will receive from any process if unspecified.
2081        group (ProcessGroup, optional): The process group to work on. If None,
2082            the default process group will be used.
2083        tag (int, optional): Tag to match recv with remote send
2084
2085    Returns:
2086        A distributed request object.
2087        None, if not part of the group
2088
2089    """
2090    _check_single_tensor(tensor, "tensor")
2091    if _rank_not_in_group(group):
2092        _warn_not_in_group("irecv")
2093        return None
2094
2095    if tensor.is_complex():
2096        tensor = torch.view_as_real(tensor)
2097
2098    if group is None or group is GroupMember.WORLD:
2099        pg = _get_default_group()
2100    else:
2101        pg = group
2102
2103    if src is None:
2104        return pg.recv_anysource([tensor], tag)
2105    else:
2106        if pg is GroupMember.WORLD:
2107            return pg.recv([tensor], src, tag)
2108        else:
2109            group_src_rank = get_group_rank(pg, src)
2110            return pg.recv([tensor], group_src_rank, tag)
2111
2112
2113@_exception_logger
2114def send(
2115    tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0
2116) -> None:
2117    """
2118    Send a tensor synchronously.
2119
2120    .. warning::
2121        ``tag`` is not supported with the NCCL backend.
2122
2123    Args:
2124        tensor (Tensor): Tensor to send.
2125        dst (int): Destination rank on global process group (regardless of ``group`` argument).
2126            Destination rank should not be the same as the rank of the current process.
2127        group (ProcessGroup, optional): The process group to work on. If None,
2128            the default process group will be used.
2129        tag (int, optional): Tag to match send with remote recv
2130
2131    """
2132    if get_rank() == dst:
2133        raise ValueError(
2134            "Invalid destination rank: destination rank should not be the same as "
2135            "the rank of the current process."
2136        )
2137
2138    _check_single_tensor(tensor, "tensor")
2139    if _rank_not_in_group(group):
2140        _warn_not_in_group("send")
2141        return None
2142
2143    if tensor.is_complex():
2144        tensor = torch.view_as_real(tensor)
2145
2146    if group is None or group is GroupMember.WORLD:
2147        default_pg = _get_default_group()
2148        default_pg.send([tensor], dst, tag).wait()
2149    else:
2150        group_dst_rank = get_group_rank(group, dst)
2151        group.send([tensor], group_dst_rank, tag).wait()
2152
2153
2154@_exception_logger
2155def recv(
2156    tensor: torch.Tensor,
2157    src: Optional[int] = None,
2158    group: Optional[ProcessGroup] = None,
2159    tag: int = 0,
2160) -> int:
2161    """
2162    Receives a tensor synchronously.
2163
2164    .. warning::
2165        ``tag`` is not supported with the NCCL backend.
2166
2167    Args:
2168        tensor (Tensor): Tensor to fill with received data.
2169        src (int, optional): Source rank on global process group (regardless of ``group`` argument).
2170            Will receive from any process if unspecified.
2171        group (ProcessGroup, optional): The process group to work on. If None,
2172            the default process group will be used.
2173        tag (int, optional): Tag to match recv with remote send
2174
2175    Returns:
2176        Sender rank
2177        -1, if not part of the group
2178
2179    """
2180    _check_single_tensor(tensor, "tensor")
2181    if _rank_not_in_group(group):
2182        _warn_not_in_group("recv")
2183        return -1
2184
2185    if tensor.is_complex():
2186        tensor = torch.view_as_real(tensor)
2187
2188    pg = group or _get_default_group()
2189
2190    if src is None:
2191        work = pg.recv_anysource([tensor], tag)
2192        work.wait()
2193        src_rank = work._source_rank()
2194        if group is None or group is GroupMember.WORLD:
2195            return src_rank
2196        else:
2197            return get_global_rank(pg, src_rank)
2198    else:
2199        if group is None or group is GroupMember.WORLD:
2200            pg.recv([tensor], src, tag).wait()
2201        else:
2202            group_src_rank = get_group_rank(pg, src)
2203            pg.recv([tensor], group_src_rank, tag).wait()
2204        return src
2205
2206
2207class _IllegalWork(Work):
2208    def __getattribute__(self, name):
2209        if name in [
2210            "is_success",
2211            "exception",
2212            "wait",
2213            "source_rank",
2214            "_source_rank",
2215            "result",
2216            "synchronize",
2217        ]:
2218            raise ValueError(f"Illegal to call {name} on IllegalWork object")
2219
2220
2221class _CoalescingManager:
2222    def __init__(self) -> None:
2223        self.works: List[Work] = []
2224
2225    def append(self, work: Work):
2226        if work:
2227            self.works.append(work)
2228
2229    def wait(self):
2230        for work in self.works:
2231            work.wait()
2232
2233
2234@contextlib.contextmanager
2235def _coalescing_manager(
2236    group: Optional[ProcessGroup] = None,
2237    device: Optional[torch.device] = None,
2238    async_ops: Optional[bool] = False,
2239):
2240    """
2241    Context manager used to coalesce collectives or P2P operations when possible.
2242
2243    Args:
2244        group (`ProcessGroup`, optional): The process group to work on. If None,
2245            the default process group will be used.
2246        device (`torch.device`, optional): Default is None, set to a device if
2247            there isn't a `**_coalesced` implementation by the backend.
2248        async_ops (`bool`, optional): whether the coalesced ops are async ops.
2249
2250    Examples:
2251        >>> # xdoctest: +SKIP("no rank")
2252        >>> # Synchronous ops
2253        >>> with _coalescing_manager():
2254        >>>     for i in range(num_colls):
2255        >>>         dist.all_reduce(tensors[i])
2256        >>> # Asynchronous ops
2257        >>> with _coalescing_manager(async_ops=True) as cm:
2258        >>>     for i in range(num_colls):
2259        >>>         dist.all_reduce(tensors[i])
2260        >>> cm.wait()
2261
2262    .. warning::
2263       :func:`_coalescing_manager` currently do not support coalescing
2264       all-reduces with different reduce operators, e.g.  `ReduceOp.SUM` mixed
2265       with `ReduceOp.PRODUCT`.
2266    """
2267    group = group or _get_default_group()
2268    op_list = _world.pg_coalesce_state.setdefault(group, [])
2269    if op_list:
2270        raise ValueError(
2271            "ProcessGroup has non-empty op list at the start of coalescing"
2272        )
2273    if device:
2274        group._start_coalescing(device)
2275    cm = _CoalescingManager()
2276    yield cm
2277    op_list = _world.pg_coalesce_state.pop(group)
2278    if op_list:
2279        # Collectives supporting "Fast Path" coalescing are captured.
2280        # See implementation in corresponding collective APIs.
2281        # Currently supported:
2282        # - coalesced `all_reduce`
2283        # - coalesced `all_gather_into_tensor`
2284        # - coalesced `reduce_scatter_tensor`
2285        op0 = op_list[0].op
2286        if op0 == all_reduce:
2287            tensors = []
2288            for op in op_list:
2289                tensors.append(op.tensor)
2290            all_reduce_opts = AllreduceCoalescedOptions()
2291            all_reduce_opts.reduceOp = not_none(op_list[0].redop)
2292            work = group.allreduce_coalesced(tensors, all_reduce_opts)
2293        elif op0 == all_gather_into_tensor:
2294            inputs = []
2295            outputs = []
2296            for op in op_list:
2297                inputs.append(op.tensor)
2298                outputs.append(not_none(op.dst_tensor))
2299            work = group.allgather_into_tensor_coalesced(outputs, inputs)
2300        elif op0 == reduce_scatter_tensor:
2301            inputs = []
2302            outputs = []
2303            for op in op_list:
2304                inputs.append(op.tensor)
2305                outputs.append(not_none(op.dst_tensor))
2306            reduce_opts = ReduceScatterOptions()
2307            reduce_opts.reduceOp = not_none(op_list[0].redop)
2308            work = group.reduce_scatter_tensor_coalesced(outputs, inputs, reduce_opts)
2309        else:
2310            raise AssertionError(
2311                f"Coalescing manager does not support fast-path coalescing of {op0}, "
2312                f"yet {op0} is still recorded in op list. This is an internal error of c10d."
2313            )
2314
2315    if device:
2316        # Old style of letting each coll inside the context manager to call into C++ counterpart via python binding
2317        work = group._end_coalescing(device)
2318
2319    if async_ops:
2320        cm.append(work)  # type: ignore[possibly-undefined]
2321    else:
2322        work.wait()  # type: ignore[possibly-undefined]
2323
2324
2325def batch_isend_irecv(p2p_op_list):
2326    """
2327    Send or Receive a batch of tensors asynchronously and return a list of requests.
2328
2329    Process each of the operations in ``p2p_op_list`` and return the corresponding
2330    requests. NCCL, Gloo, and UCC backend are currently supported.
2331
2332    Args:
2333        p2p_op_list: A list of point-to-point operations(type of each operator is
2334            ``torch.distributed.P2POp``). The order of the isend/irecv in the list
2335            matters and it needs to match with corresponding isend/irecv on the
2336            remote end.
2337
2338    Returns:
2339        A list of distributed request objects returned by calling the corresponding
2340        op in the op_list.
2341
2342    Examples:
2343        >>> # xdoctest: +SKIP("no rank")
2344        >>> send_tensor = torch.arange(2, dtype=torch.float32) + 2 * rank
2345        >>> recv_tensor = torch.randn(2, dtype=torch.float32)
2346        >>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1)%world_size)
2347        >>> recv_op = dist.P2POp(dist.irecv, recv_tensor, (rank - 1 + world_size)%world_size)
2348        >>> reqs = batch_isend_irecv([send_op, recv_op])
2349        >>> for req in reqs:
2350        >>>     req.wait()
2351        >>> recv_tensor
2352        tensor([2, 3])     # Rank 0
2353        tensor([0, 1])     # Rank 1
2354
2355    .. note:: Note that when this API is used with the NCCL PG backend, users must set
2356        the current GPU device with `torch.cuda.set_device`, otherwise it will
2357        lead to unexpected hang issues.
2358
2359        In addition, if this API is the first collective call in the ``group``
2360        passed to ``dist.P2POp``, all ranks of the ``group`` must participate in
2361        this API call; otherwise, the behavior is undefined. If this API call is
2362        not the first collective call in the ``group``, batched P2P operations
2363        involving only a subset of ranks of the ``group`` are allowed.
2364    """
2365    _check_p2p_op_list(p2p_op_list)
2366    group = p2p_op_list[0].group
2367    device = p2p_op_list[0].tensor.device
2368    if device.type == "cuda":
2369        # NCCL style coalescing
2370        with _coalescing_manager(group, device, async_ops=True) as cm:
2371            for p2p_op in p2p_op_list:
2372                p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag)
2373        return cm.works
2374    else:
2375        # Backward support for Gloo
2376        reqs = []
2377        for p2p_op in p2p_op_list:
2378            work = p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag)
2379            if work:
2380                reqs.append(work)
2381        return reqs
2382
2383
2384@_exception_logger
2385def broadcast(tensor, src, group=None, async_op=False):
2386    """
2387    Broadcasts the tensor to the whole group.
2388
2389    ``tensor`` must have the same number of elements in all processes
2390    participating in the collective.
2391
2392    Args:
2393        tensor (Tensor): Data to be sent if ``src`` is the rank of current
2394            process, and tensor to be used to save received data otherwise.
2395        src (int): Source rank on global process group (regardless of ``group`` argument).
2396        group (ProcessGroup, optional): The process group to work on. If None,
2397            the default process group will be used.
2398        async_op (bool, optional): Whether this op should be an async op
2399
2400    Returns:
2401        Async work handle, if async_op is set to True.
2402        None, if not async_op or if not part of the group
2403
2404    """
2405    _check_single_tensor(tensor, "tensor")
2406    if _rank_not_in_group(group):
2407        _warn_not_in_group("broadcast")
2408        return
2409
2410    opts = BroadcastOptions()
2411    opts.rootRank = src
2412    opts.rootTensor = 0
2413    opts.asyncOp = async_op
2414
2415    if group is None or group is GroupMember.WORLD:
2416        default_pg = _get_default_group()
2417        work = default_pg.broadcast([tensor], opts)
2418    else:
2419        group_src_rank = get_group_rank(group, src)
2420        opts.rootRank = group_src_rank
2421        work = group.broadcast([tensor], opts)
2422    if async_op:
2423        return work
2424    else:
2425        work.wait()
2426
2427
2428@_exception_logger
2429def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
2430    """
2431    Reduces the tensor data across all machines in a way that all get the final result.
2432
2433    After the call ``tensor`` is going to be bitwise identical in all processes.
2434
2435    Complex tensors are supported.
2436
2437    Args:
2438        tensor (Tensor): Input and output of the collective. The function
2439            operates in-place.
2440        op (optional): One of the values from
2441            ``torch.distributed.ReduceOp``
2442            enum.  Specifies an operation used for element-wise reductions.
2443        group (ProcessGroup, optional): The process group to work on. If None,
2444            the default process group will be used.
2445        async_op (bool, optional): Whether this op should be an async op
2446
2447    Returns:
2448        Async work handle, if async_op is set to True.
2449        None, if not async_op or if not part of the group
2450
2451    Examples:
2452        >>> # xdoctest: +SKIP("no rank")
2453        >>> # All tensors below are of torch.int64 type.
2454        >>> # We have 2 process groups, 2 ranks.
2455        >>> device = torch.device(f'cuda:{rank}')
2456        >>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
2457        >>> tensor
2458        tensor([1, 2], device='cuda:0') # Rank 0
2459        tensor([3, 4], device='cuda:1') # Rank 1
2460        >>> dist.all_reduce(tensor, op=ReduceOp.SUM)
2461        >>> tensor
2462        tensor([4, 6], device='cuda:0') # Rank 0
2463        tensor([4, 6], device='cuda:1') # Rank 1
2464
2465        >>> # All tensors below are of torch.cfloat type.
2466        >>> # We have 2 process groups, 2 ranks.
2467        >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j)
2468        >>> tensor
2469        tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
2470        tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
2471        >>> dist.all_reduce(tensor, op=ReduceOp.SUM)
2472        >>> tensor
2473        tensor([4.+4.j, 6.+6.j], device='cuda:0') # Rank 0
2474        tensor([4.+4.j, 6.+6.j], device='cuda:1') # Rank 1
2475
2476    """
2477    _check_single_tensor(tensor, "tensor")
2478    if _rank_not_in_group(group):
2479        _warn_not_in_group("all_reduce")
2480        return
2481
2482    if tensor.is_complex():
2483        if not supports_complex(op):
2484            raise ValueError(f"all_reduce does not support {op} on complex tensors")
2485        tensor = torch.view_as_real(tensor)
2486
2487    opts = AllreduceOptions()
2488    opts.reduceOp = op
2489    if group is None:
2490        group = _get_default_group()
2491
2492    if group in _world.pg_coalesce_state.keys():
2493        # We are in coalescing context, do not issue single operation, just append a collective representation
2494        coll = _CollOp(all_reduce, tensor, None, op, None)
2495        _world.pg_coalesce_state[group].append(coll)
2496        if async_op:
2497            return _IllegalWork()
2498        else:
2499            return None
2500
2501    work = group.allreduce([tensor], opts)
2502
2503    if async_op:
2504        return work
2505    else:
2506        work.wait()
2507
2508
2509@_exception_logger
2510@deprecated(
2511    "`torch.distributed.all_reduce_coalesced` will be deprecated. If you must "
2512    "use it, please revisit our documentation later at "
2513    "https://pytorch.org/docs/main/distributed.html#collective-functions",
2514    category=FutureWarning,
2515)
2516def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
2517    """
2518    WARNING: at this time individual shape checking is not implemented across nodes.
2519
2520    For example, if the rank 0 node passes [torch.rand(4), torch.rand(2)] and the
2521    rank 1 node passes [torch.rand(2), torch.rand(2), torch.rand(2)], the allreduce
2522    operation will proceed without complaint and return erroneous outputs. This lack
2523    of shape checking results in significant performance improvements but users of this
2524    function should take extra care to ensure that each node passes in tensors whose
2525    shapes match across nodes.
2526
2527    Reduces each tensor in tensors (residing on the same device) across all machines
2528    in such a way that all get the final result.
2529
2530    After the call each tensor in tensors is going to bitwise identical
2531    in all processes.
2532
2533    Complex tensors are supported.
2534
2535    Args:
2536        tensors (Union[List[Tensor], Tensor]): Input and output of the collective.
2537            The function operates in-place.
2538        op (Optional[ReduceOp]): One of the values from
2539            ``torch.distributed.ReduceOp`` enum. Specifies an operation used for
2540            element-wise reductions.
2541        group (ProcessGroup, optional): The process group to work on. If None,
2542            the default process group will be used.
2543        async_op (Optional[bool]): Whether this op should be an async op.
2544
2545    Returns:
2546        Async work handle, if async_op is set to True.
2547        None, if not async_op or if not part of the group.
2548
2549    """
2550    if isinstance(tensors, torch.Tensor):
2551        tensors = [tensors]
2552    _check_tensor_list(tensors, "tensor")
2553    _ensure_all_tensors_same_dtype(tensors)
2554    if _rank_not_in_group(group):
2555        _warn_not_in_group("all_reduce_coalesced")
2556        return
2557
2558    if any(t.is_complex() for t in tensors) and not supports_complex(op):
2559        raise ValueError(f"all_reduce does not support {op} on complex tensors")
2560
2561    tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors]
2562
2563    opts = AllreduceCoalescedOptions()
2564    opts.reduceOp = op
2565    group = group or _get_default_group()
2566    work = group.allreduce_coalesced(tensors, opts)
2567
2568    if async_op:
2569        return work.get_future()
2570    else:
2571        work.wait()
2572
2573
2574@_exception_logger
2575def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
2576    """
2577    Reduces the tensor data across all machines.
2578
2579    Only the process with rank ``dst`` is going to receive the final result.
2580
2581    Args:
2582        tensor (Tensor): Input and output of the collective. The function
2583            operates in-place.
2584        dst (int): Destination rank on global process group (regardless of ``group`` argument)
2585        op (optional): One of the values from
2586            ``torch.distributed.ReduceOp``
2587            enum.  Specifies an operation used for element-wise reductions.
2588        group (ProcessGroup, optional): The process group to work on. If None,
2589            the default process group will be used.
2590        async_op (bool, optional): Whether this op should be an async op
2591
2592    Returns:
2593        Async work handle, if async_op is set to True.
2594        None, if not async_op or if not part of the group
2595
2596    """
2597    _check_single_tensor(tensor, "tensor")
2598    if _rank_not_in_group(group):
2599        _warn_not_in_group("reduce")
2600        return
2601
2602    opts = ReduceOptions()
2603    opts.reduceOp = op
2604    opts.rootRank = dst
2605
2606    if group is None or group is GroupMember.WORLD:
2607        default_pg = _get_default_group()
2608        work = default_pg.reduce([tensor], opts)
2609    else:
2610        group_dst_rank = get_group_rank(group, dst)
2611        opts.rootRank = group_dst_rank
2612        work = group.reduce([tensor], opts)
2613
2614    if async_op:
2615        return work
2616    else:
2617        work.wait()
2618
2619
2620def _object_to_tensor(obj, device, group):
2621    f = io.BytesIO()
2622    _pickler(f).dump(obj)
2623    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())  # type: ignore[attr-defined]
2624    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
2625    # Otherwise, it will casue 100X slowdown.
2626    # See: https://github.com/pytorch/pytorch/issues/65696
2627    byte_tensor = torch.ByteTensor(byte_storage).to(device)
2628    if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
2629        backend = get_backend(group)
2630        if backend == Backend.NCCL:
2631            hash = torch._C._distributed_c10d._hash_tensors([byte_tensor])
2632            logger.warning(
2633                "_object_to_tensor size: %s hash value: %s", byte_tensor.numel(), hash
2634            )
2635    local_size = torch.LongTensor([byte_tensor.numel()]).to(device)
2636    return byte_tensor, local_size
2637
2638
2639def _tensor_to_object(tensor, tensor_size, group):
2640    if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
2641        backend = get_backend(group)
2642        if backend == Backend.NCCL:
2643            hash = torch._C._distributed_c10d._hash_tensors([tensor])
2644            logger.warning(
2645                "_tensor_to_object size: %s hash value: %s", tensor.numel(), hash
2646            )
2647    tensor = tensor.cpu()
2648    buf = tensor.numpy().tobytes()[:tensor_size]
2649    return _unpickler(io.BytesIO(buf)).load()
2650
2651
2652@_exception_logger
2653def all_gather_object(object_list, obj, group=None):
2654    """
2655    Gathers picklable objects from the whole group into a list.
2656
2657    Similar to :func:`all_gather`, but Python objects can be passed in.
2658    Note that the object must be picklable in order to be gathered.
2659
2660    Args:
2661        object_list (list[Any]): Output list. It should be correctly sized as the
2662            size of the group for this collective and will contain the output.
2663        obj (Any): Pickable Python object to be broadcast from current process.
2664        group (ProcessGroup, optional): The process group to work on. If None,
2665            the default process group will be used. Default is ``None``.
2666
2667    Returns:
2668        None. If the calling rank is part of this group, the output of the
2669        collective will be populated into the input ``object_list``. If the
2670        calling rank is not part of the group, the passed in ``object_list`` will
2671        be unmodified.
2672
2673    .. note:: Note that this API differs slightly from the :func:`all_gather`
2674        collective since it does not provide an ``async_op`` handle and thus
2675        will be a blocking call.
2676
2677    .. note:: For NCCL-based processed groups, internal tensor representations
2678        of objects must be moved to the GPU device before communication takes
2679        place. In this case, the device used is given by
2680        ``torch.cuda.current_device()`` and it is the user's responsiblity to
2681        ensure that this is set so that each rank has an individual GPU, via
2682        ``torch.cuda.set_device()``.
2683
2684    .. warning::
2685        :func:`all_gather_object` uses ``pickle`` module implicitly, which is
2686        known to be insecure. It is possible to construct malicious pickle data
2687        which will execute arbitrary code during unpickling. Only call this
2688        function with data you trust.
2689
2690    .. warning::
2691        Calling :func:`all_gather_object` with GPU tensors is not well supported
2692        and inefficient as it incurs GPU -> CPU transfer since tensors would be
2693        pickled. Please consider using :func:`all_gather` instead.
2694
2695    Example::
2696        >>> # xdoctest: +SKIP("need process group init")
2697        >>> # Note: Process group initialization omitted on each rank.
2698        >>> import torch.distributed as dist
2699        >>> # Assumes world_size of 3.
2700        >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
2701        >>> output = [None for _ in gather_objects]
2702        >>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
2703        >>> output
2704        ['foo', 12, {1: 2}]
2705    """
2706    if _rank_not_in_group(group):
2707        _warn_not_in_group("all_gather_object")
2708        return
2709
2710    current_device = _get_pg_default_device(group)
2711    input_tensor, local_size = _object_to_tensor(obj, current_device, group)
2712
2713    # Gather all local sizes. This is so that we can find the max size, and index
2714    # until the correct size when deserializing the tensors.
2715    group_size = get_world_size(group=group)
2716    object_sizes_tensor = torch.zeros(
2717        group_size, dtype=torch.long, device=current_device
2718    )
2719    object_size_list = [
2720        object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
2721    ]
2722    # Allgather tensor sizes
2723    all_gather(object_size_list, local_size, group=group)
2724    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
2725    # Resize tensor to max size across all ranks.
2726    input_tensor.resize_(max_object_size)
2727    coalesced_output_tensor = torch.empty(
2728        max_object_size * group_size, dtype=torch.uint8, device=current_device
2729    )
2730    # Output tensors are nonoverlapping views of coalesced_output_tensor
2731    output_tensors = [
2732        coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
2733        for i in range(group_size)
2734    ]
2735    all_gather(output_tensors, input_tensor, group=group)
2736    # Deserialize outputs back to object.
2737    for i, tensor in enumerate(output_tensors):
2738        tensor = tensor.type(torch.uint8)
2739        tensor_size = object_size_list[i]
2740        object_list[i] = _tensor_to_object(tensor, tensor_size, group)
2741
2742
2743@_exception_logger
2744def gather_object(obj, object_gather_list=None, dst=0, group=None):
2745    """
2746    Gathers picklable objects from the whole group in a single process.
2747
2748    Similar to :func:`gather`, but Python objects can be passed in. Note that the
2749    object must be picklable in order to be gathered.
2750
2751    Args:
2752        obj (Any): Input object. Must be picklable.
2753        object_gather_list (list[Any]): Output list. On the ``dst`` rank, it
2754            should be correctly sized as the size of the group for this
2755            collective and will contain the output. Must be ``None`` on non-dst
2756            ranks. (default is ``None``)
2757        dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). (default is 0)
2758        group: (ProcessGroup, optional): The process group to work on. If None,
2759            the default process group will be used. Default is ``None``.
2760
2761    Returns:
2762        None. On the ``dst`` rank, ``object_gather_list`` will contain the
2763        output of the collective.
2764
2765    .. note:: Note that this API differs slightly from the gather collective
2766        since it does not provide an async_op handle and thus will be a blocking
2767        call.
2768
2769    .. note:: For NCCL-based processed groups, internal tensor representations
2770        of objects must be moved to the GPU device before communication takes
2771        place. In this case, the device used is given by
2772        ``torch.cuda.current_device()`` and it is the user's responsiblity to
2773        ensure that this is set so that each rank has an individual GPU, via
2774        ``torch.cuda.set_device()``.
2775
2776    .. warning::
2777        :func:`gather_object` uses ``pickle`` module implicitly, which is
2778        known to be insecure. It is possible to construct malicious pickle data
2779        which will execute arbitrary code during unpickling. Only call this
2780        function with data you trust.
2781
2782    .. warning::
2783        Calling :func:`gather_object` with GPU tensors is not well supported
2784        and inefficient as it incurs GPU -> CPU transfer since tensors would be
2785        pickled. Please consider using :func:`gather` instead.
2786
2787    Example::
2788        >>> # xdoctest: +SKIP("need process group init")
2789        >>> # Note: Process group initialization omitted on each rank.
2790        >>> import torch.distributed as dist
2791        >>> # Assumes world_size of 3.
2792        >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
2793        >>> output = [None for _ in gather_objects]
2794        >>> dist.gather_object(
2795        ...     gather_objects[dist.get_rank()],
2796        ...     output if dist.get_rank() == 0 else None,
2797        ...     dst=0
2798        ... )
2799        >>> # On rank 0
2800        >>> output
2801        ['foo', 12, {1: 2}]
2802    """
2803    if _rank_not_in_group(group):
2804        _warn_not_in_group("gather_object")
2805        return
2806
2807    # Ensure object_gather_list is specified appropriately.
2808    my_rank = get_rank()
2809    _validate_output_list_for_rank(my_rank, dst, object_gather_list)
2810    current_device = _get_pg_default_device(group)
2811    input_tensor, local_size = _object_to_tensor(obj, current_device, group)
2812
2813    # Gather all local sizes. This is so that we can find the max size, and index
2814    # until the correct size when deserializing the tensors.
2815    group_size = get_world_size(group=group)
2816    object_sizes_tensor = torch.zeros(
2817        group_size, dtype=torch.long, device=current_device
2818    )
2819    object_size_list = [
2820        object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
2821    ]
2822    # Allgather tensor sizes. An all-gather is needed here despite this being a
2823    # gather, since each rank needs to broadcast a tensor of the same (maximal)
2824    # size.
2825    all_gather(object_size_list, local_size, group=group)
2826    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
2827    # Resize tensor to max size across all ranks.
2828    input_tensor.resize_(max_object_size)
2829    # Avoid populating output tensors if the result won't be gathered on this rank.
2830    if my_rank == dst:
2831        coalesced_output_tensor = torch.empty(
2832            max_object_size * group_size, dtype=torch.uint8, device=current_device
2833        )
2834        # Output tensors are nonoverlapping views of coalesced_output_tensor
2835        output_tensors = [
2836            coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
2837            for i in range(group_size)
2838        ]
2839    # All ranks call gather with equal-sized tensors.
2840    gather(
2841        input_tensor,
2842        gather_list=output_tensors if my_rank == dst else None,  # type: ignore[possibly-undefined]
2843        dst=dst,
2844        group=group,
2845    )
2846    if my_rank != dst:
2847        return
2848    for i, tensor in enumerate(output_tensors):
2849        tensor = tensor.type(torch.uint8)
2850        tensor_size = object_size_list[i]
2851        object_gather_list[i] = _tensor_to_object(tensor, tensor_size, group)
2852
2853
2854@_exception_logger
2855def send_object_list(object_list, dst, group=None, device=None):
2856    """
2857    Sends picklable objects in ``object_list`` synchronously.
2858
2859    Similar to :func:`send`, but Python objects can be passed in.
2860    Note that all objects in ``object_list`` must be picklable in order to be
2861    sent.
2862
2863    Args:
2864        object_list (List[Any]): List of input objects to sent.
2865            Each object must be picklable. Receiver must provide lists of equal sizes.
2866        dst (int): Destination rank to send ``object_list`` to.
2867            Destination rank is based on global process group (regardless of ``group`` argument)
2868        group: (ProcessGroup, optional): The process group to work on. If None,
2869            the default process group will be used. Default is ``None``.
2870        device (``torch.device``, optional): If not None, the objects are
2871            serialized and converted to tensors which are moved to the
2872            ``device`` before sending. Default is ``None``.
2873
2874    Returns:
2875        ``None``.
2876
2877    .. note:: For NCCL-based process groups, internal tensor representations
2878        of objects must be moved to the GPU device before communication takes
2879        place. In this case, the device used is given by
2880        ``torch.cuda.current_device()`` and it is the user's responsibility to
2881        ensure that this is set so that each rank has an individual GPU, via
2882        ``torch.cuda.set_device()``.
2883
2884    .. warning::
2885        :func:`send_object_list` uses ``pickle`` module implicitly, which
2886        is known to be insecure. It is possible to construct malicious pickle
2887        data which will execute arbitrary code during unpickling. Only call this
2888        function with data you trust.
2889
2890    .. warning::
2891        Calling :func:`send_object_list` with GPU tensors is not well supported
2892        and inefficient as it incurs GPU -> CPU transfer since tensors would be
2893        pickled. Please consider using :func:`send` instead.
2894
2895    Example::
2896        >>> # xdoctest: +SKIP("need process group init")
2897        >>> # Note: Process group initialization omitted on each rank.
2898        >>> import torch.distributed as dist
2899        >>> # Assumes backend is not NCCL
2900        >>> device = torch.device("cpu")
2901        >>> if dist.get_rank() == 0:
2902        >>>     # Assumes world_size of 2.
2903        >>>     objects = ["foo", 12, {1: 2}] # any picklable object
2904        >>>     dist.send_object_list(objects, dst=1, device=device)
2905        >>> else:
2906        >>>     objects = [None, None, None]
2907        >>>     dist.recv_object_list(objects, src=0, device=device)
2908        >>> objects
2909        ['foo', 12, {1: 2}]
2910    """
2911    if get_rank() == dst:
2912        raise ValueError(
2913            "Invalid destination rank: destination rank should not be the same as "
2914            "the rank of the current process."
2915        )
2916
2917    if _rank_not_in_group(group):
2918        _warn_not_in_group("send_object_list")
2919        return
2920
2921    # Current device selection.
2922    # To preserve backwards compatibility, ``device`` is default to ``None``
2923    # in which case we run current logic of device selection, i.e.
2924    # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the
2925    # case it is not ``None`` we move the size and object tensors to be
2926    # sent to this device.
2927    current_device = device or _get_pg_default_device(group)
2928    # Serialize object_list elements to tensors on src rank.
2929    tensor_list, size_list = zip(
2930        *[_object_to_tensor(obj, current_device, group) for obj in object_list]
2931    )
2932    object_sizes_tensor = torch.cat(size_list)
2933
2934    # Send object sizes
2935    send(object_sizes_tensor, dst=dst, group=group)
2936
2937    # Concatenate and send serialized object tensors
2938    # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list
2939    # has only one element, we can skip the copy.
2940    if len(tensor_list) == 1:  # type: ignore[possibly-undefined]
2941        object_tensor = tensor_list[0]
2942    else:
2943        object_tensor = torch.cat(tensor_list)
2944
2945    send(object_tensor, dst=dst, group=group)
2946
2947
2948@_exception_logger
2949def recv_object_list(object_list, src=None, group=None, device=None):
2950    """
2951    Receives picklable objects in ``object_list`` synchronously.
2952
2953    Similar to :func:`recv`, but can receive Python objects.
2954
2955    Args:
2956        object_list (List[Any]): List of objects to receive into.
2957            Must provide a list of sizes equal to the size of the list being sent.
2958        src (int, optional): Source rank from which to recv ``object_list``.
2959            Source rank is based on global process group (regardless of ``group`` argument)
2960            Will receive from any rank if set to None. Default is ``None``.
2961        group: (ProcessGroup, optional): The process group to work on. If None,
2962            the default process group will be used. Default is ``None``.
2963        device (``torch.device``, optional): If not None, receives on this device.
2964            Default is ``None``.
2965
2966    Returns:
2967        Sender rank. -1 if rank is not part of the group. If rank is part of the group,
2968        ``object_list`` will contain the sent objects from ``src`` rank.
2969
2970    .. note:: For NCCL-based process groups, internal tensor representations
2971        of objects must be moved to the GPU device before communication takes
2972        place. In this case, the device used is given by
2973        ``torch.cuda.current_device()`` and it is the user's responsibility to
2974        ensure that this is set so that each rank has an individual GPU, via
2975        ``torch.cuda.set_device()``.
2976
2977    .. warning::
2978        :func:`recv_object_list` uses ``pickle`` module implicitly, which
2979        is known to be insecure. It is possible to construct malicious pickle
2980        data which will execute arbitrary code during unpickling. Only call this
2981        function with data you trust.
2982
2983    .. warning::
2984        Calling :func:`recv_object_list` with GPU tensors is not well supported
2985        and inefficient as it incurs GPU -> CPU transfer since tensors would be
2986        pickled. Please consider using :func:`recv` instead.
2987
2988    Example::
2989        >>> # xdoctest: +SKIP("need process group init")
2990        >>> # Note: Process group initialization omitted on each rank.
2991        >>> import torch.distributed as dist
2992        >>> # Assumes backend is not NCCL
2993        >>> device = torch.device("cpu")
2994        >>> if dist.get_rank() == 0:
2995        >>>     # Assumes world_size of 2.
2996        >>>     objects = ["foo", 12, {1: 2}] # any picklable object
2997        >>>     dist.send_object_list(objects, dst=1, device=device)
2998        >>> else:
2999        >>>     objects = [None, None, None]
3000        >>>     dist.recv_object_list(objects, src=0, device=device)
3001        >>> objects
3002        ['foo', 12, {1: 2}]
3003    """
3004    if _rank_not_in_group(group):
3005        _warn_not_in_group("recv_object_list")
3006        return -1
3007
3008    # Current device selection.
3009    # To preserve backwards compatibility, ``device`` is default to ``None``
3010    # in which case we run current logic of device selection, i.e.
3011    # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the
3012    # case it is not ``None`` we move the size and object tensors to be
3013    # received to this device.
3014    current_device = device or _get_pg_default_device(group)
3015    object_sizes_tensor = torch.empty(
3016        len(object_list), dtype=torch.long, device=current_device
3017    )
3018
3019    # Receive object sizes
3020    rank_sizes = recv(object_sizes_tensor, src=src, group=group)
3021
3022    # Tensor to receive serialized objects into.
3023    object_tensor = torch.empty(  # type: ignore[call-overload]
3024        torch.sum(object_sizes_tensor).item(),  # type: ignore[arg-type]
3025        dtype=torch.uint8,
3026        device=current_device,
3027    )
3028
3029    rank_objects = recv(object_tensor, src=src, group=group)
3030    assert (
3031        rank_sizes == rank_objects
3032    ), "Mismatch in return ranks for object sizes and objects."
3033    # Deserialize objects using their stored sizes.
3034    offset = 0
3035    for i, obj_size in enumerate(object_sizes_tensor):
3036        obj_view = object_tensor[offset : offset + obj_size]
3037        obj_view = obj_view.type(torch.uint8)
3038        offset += obj_size
3039        object_list[i] = _tensor_to_object(obj_view, obj_size, group)
3040    return rank_objects
3041
3042
3043@_exception_logger
3044def broadcast_object_list(object_list, src=0, group=None, device=None):
3045    """
3046    Broadcasts picklable objects in ``object_list`` to the whole group.
3047
3048    Similar to :func:`broadcast`, but Python objects can be passed in.
3049    Note that all objects in ``object_list`` must be picklable in order to be
3050    broadcasted.
3051
3052    Args:
3053        object_list (List[Any]): List of input objects to broadcast.
3054            Each object must be picklable. Only objects on the ``src`` rank will
3055            be broadcast, but each rank must provide lists of equal sizes.
3056        src (int): Source rank from which to broadcast ``object_list``.
3057            Source rank is based on global process group (regardless of ``group`` argument)
3058        group: (ProcessGroup, optional): The process group to work on. If None,
3059            the default process group will be used. Default is ``None``.
3060        device (``torch.device``, optional): If not None, the objects are
3061            serialized and converted to tensors which are moved to the
3062            ``device`` before broadcasting. Default is ``None``.
3063
3064    Returns:
3065        ``None``. If rank is part of the group, ``object_list`` will contain the
3066        broadcasted objects from ``src`` rank.
3067
3068    .. note:: For NCCL-based process groups, internal tensor representations
3069        of objects must be moved to the GPU device before communication takes
3070        place. In this case, the device used is given by
3071        ``torch.cuda.current_device()`` and it is the user's responsibility to
3072        ensure that this is set so that each rank has an individual GPU, via
3073        ``torch.cuda.set_device()``.
3074
3075    .. note:: Note that this API differs slightly from the :func:`broadcast`
3076        collective since it does not provide an ``async_op`` handle and thus
3077        will be a blocking call.
3078
3079    .. warning::
3080        :func:`broadcast_object_list` uses ``pickle`` module implicitly, which
3081        is known to be insecure. It is possible to construct malicious pickle
3082        data which will execute arbitrary code during unpickling. Only call this
3083        function with data you trust.
3084
3085    .. warning::
3086        Calling :func:`broadcast_object_list` with GPU tensors is not well supported
3087        and inefficient as it incurs GPU -> CPU transfer since tensors would be
3088        pickled. Please consider using :func:`broadcast` instead.
3089
3090    Example::
3091        >>> # xdoctest: +SKIP("need process group init")
3092        >>> # Note: Process group initialization omitted on each rank.
3093        >>> import torch.distributed as dist
3094        >>> if dist.get_rank() == 0:
3095        >>>     # Assumes world_size of 3.
3096        >>>     objects = ["foo", 12, {1: 2}] # any picklable object
3097        >>> else:
3098        >>>     objects = [None, None, None]
3099        >>> # Assumes backend is not NCCL
3100        >>> device = torch.device("cpu")
3101        >>> dist.broadcast_object_list(objects, src=0, device=device)
3102        >>> objects
3103        ['foo', 12, {1: 2}]
3104    """
3105    if _rank_not_in_group(group):
3106        _warn_not_in_group("broadcast_object_list")
3107        return
3108
3109    # Current device selection.
3110    # To preserve backwards compatibility, ``device`` is default to ``None``
3111    # in which case we run current logic of device selection, i.e.
3112    # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the
3113    # case it is not ``None`` we move the size and object tensors to be
3114    # broadcasted to this device.
3115    current_device = device or _get_pg_default_device(group)
3116    my_rank = get_rank()
3117    # Serialize object_list elements to tensors on src rank.
3118    if my_rank == src:
3119        tensor_list, size_list = zip(
3120            *[_object_to_tensor(obj, current_device, group) for obj in object_list]
3121        )
3122        object_sizes_tensor = torch.cat(size_list)
3123    else:
3124        object_sizes_tensor = torch.empty(
3125            len(object_list), dtype=torch.long, device=current_device
3126        )
3127
3128    # Broadcast object sizes
3129    broadcast(object_sizes_tensor, src=src, group=group)
3130
3131    # Concatenate and broadcast serialized object tensors
3132    # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list
3133    # has only one element, we can skip the copy.
3134    if my_rank == src:
3135        if len(tensor_list) == 1:  # type: ignore[possibly-undefined]
3136            object_tensor = tensor_list[0]
3137        else:
3138            object_tensor = torch.cat(tensor_list)
3139    else:
3140        object_tensor = torch.empty(  # type: ignore[call-overload]
3141            torch.sum(object_sizes_tensor).item(),  # type: ignore[arg-type]
3142            dtype=torch.uint8,
3143            device=current_device,
3144        )
3145
3146    broadcast(object_tensor, src=src, group=group)
3147    # Deserialize objects using their stored sizes.
3148    offset = 0
3149    if my_rank != src:
3150        for i, obj_size in enumerate(object_sizes_tensor):
3151            obj_view = object_tensor[offset : offset + obj_size]
3152            obj_view = obj_view.type(torch.uint8)
3153            offset += obj_size
3154            object_list[i] = _tensor_to_object(obj_view, obj_size, group)
3155
3156
3157@_exception_logger
3158def scatter_object_list(
3159    scatter_object_output_list, scatter_object_input_list, src=0, group=None
3160):
3161    """
3162    Scatters picklable objects in ``scatter_object_input_list`` to the whole group.
3163
3164    Similar to :func:`scatter`, but Python objects can be passed in. On
3165    each rank, the scattered object will be stored as the first element of
3166    ``scatter_object_output_list``. Note that all objects in
3167    ``scatter_object_input_list`` must be picklable in order to be scattered.
3168
3169    Args:
3170        scatter_object_output_list (List[Any]): Non-empty list whose first
3171            element will store the object scattered to this rank.
3172        scatter_object_input_list (List[Any]): List of input objects to scatter.
3173            Each object must be picklable. Only objects on the ``src`` rank will
3174            be scattered, and the argument can be ``None`` for non-src ranks.
3175        src (int): Source rank from which to scatter ``scatter_object_input_list``.
3176            Source rank is based on global process group (regardless of ``group`` argument).
3177        group: (ProcessGroup, optional): The process group to work on. If None,
3178            the default process group will be used. Default is ``None``.
3179
3180    Returns:
3181        ``None``. If rank is part of the group, ``scatter_object_output_list``
3182        will have its first element set to the scattered object for this rank.
3183
3184    .. note:: Note that this API differs slightly from the scatter collective
3185        since it does not provide an ``async_op`` handle and thus will be a
3186        blocking call.
3187
3188    .. warning::
3189        :func:`scatter_object_list` uses ``pickle`` module implicitly, which
3190        is known to be insecure. It is possible to construct malicious pickle
3191        data which will execute arbitrary code during unpickling. Only call this
3192        function with data you trust.
3193
3194    .. warning::
3195        Calling :func:`scatter_object_list` with GPU tensors is not well supported
3196        and inefficient as it incurs GPU -> CPU transfer since tensors would be
3197        pickled. Please consider using :func:`scatter` instead.
3198
3199    Example::
3200        >>> # xdoctest: +SKIP("need process group init")
3201        >>> # Note: Process group initialization omitted on each rank.
3202        >>> import torch.distributed as dist
3203        >>> if dist.get_rank() == 0:
3204        >>>     # Assumes world_size of 3.
3205        >>>     objects = ["foo", 12, {1: 2}] # any picklable object
3206        >>> else:
3207        >>>     # Can be any list on non-src ranks, elements are not used.
3208        >>>     objects = [None, None, None]
3209        >>> output_list = [None]
3210        >>> dist.scatter_object_list(output_list, objects, src=0)
3211        >>> # Rank i gets objects[i]. For example, on rank 2:
3212        >>> output_list
3213        [{1: 2}]
3214    """
3215    if _rank_not_in_group(group):
3216        _warn_not_in_group("scatter_object_list")
3217        return
3218
3219    if (
3220        not isinstance(scatter_object_output_list, list)
3221        or len(scatter_object_output_list) < 1
3222    ):
3223        raise ValueError(
3224            "Expected argument scatter_object_output_list to be a list of size at least 1."
3225        )
3226
3227    my_rank = get_rank()
3228    pg_device = _get_pg_default_device(group)
3229    if my_rank == src:
3230        tensor_list, tensor_sizes = zip(
3231            *[
3232                _object_to_tensor(obj, pg_device, group)
3233                for obj in scatter_object_input_list
3234            ]
3235        )
3236        tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes)
3237
3238    # Src rank broadcasts the maximum tensor size. This is because all ranks are
3239    # expected to call into scatter() with equal-sized tensors.
3240    if my_rank == src:
3241        max_tensor_size = max(tensor_sizes)  # type: ignore[possibly-undefined]
3242        for tensor in tensor_list:  # type: ignore[possibly-undefined]
3243            tensor.resize_(max_tensor_size)
3244    else:
3245        max_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device)
3246    broadcast(max_tensor_size, src=src, group=group)
3247
3248    # Scatter actual serialized objects
3249    output_tensor = torch.empty(
3250        max_tensor_size.item(), dtype=torch.uint8, device=pg_device
3251    )
3252    scatter(
3253        output_tensor,
3254        scatter_list=None if my_rank != src else tensor_list,  # type: ignore[possibly-undefined]
3255        src=src,
3256        group=group,
3257    )
3258
3259    # Scatter per-object sizes to trim tensors when deserializing back to object
3260    obj_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device)
3261    scatter(
3262        obj_tensor_size,
3263        scatter_list=None if my_rank != src else tensor_sizes,  # type: ignore[possibly-undefined]
3264        src=src,
3265        group=group,
3266    )
3267
3268    # Deserialize back to object
3269    scatter_object_output_list[0] = _tensor_to_object(
3270        output_tensor, obj_tensor_size, group
3271    )
3272
3273
3274@_exception_logger
3275def all_gather(tensor_list, tensor, group=None, async_op=False):
3276    """
3277    Gathers tensors from the whole group in a list.
3278
3279    Complex and uneven sized tensors are supported.
3280
3281    Args:
3282        tensor_list (list[Tensor]): Output list. It should contain
3283            correctly-sized tensors to be used for output of the collective.
3284            Uneven sized tensors are supported.
3285        tensor (Tensor): Tensor to be broadcast from current process.
3286        group (ProcessGroup, optional): The process group to work on. If None,
3287            the default process group will be used.
3288        async_op (bool, optional): Whether this op should be an async op
3289
3290    Returns:
3291        Async work handle, if async_op is set to True.
3292        None, if not async_op or if not part of the group
3293
3294    Examples:
3295        >>> # xdoctest: +SKIP("need process group init")
3296        >>> # All tensors below are of torch.int64 dtype.
3297        >>> # We have 2 process groups, 2 ranks.
3298        >>> device = torch.device(f'cuda:{rank}')
3299        >>> tensor_list = [torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2)]
3300        >>> tensor_list
3301        [tensor([0, 0], device='cuda:0'), tensor([0, 0], device='cuda:0')] # Rank 0
3302        [tensor([0, 0], device='cuda:1'), tensor([0, 0], device='cuda:1')] # Rank 1
3303        >>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
3304        >>> tensor
3305        tensor([1, 2], device='cuda:0') # Rank 0
3306        tensor([3, 4], device='cuda:1') # Rank 1
3307        >>> dist.all_gather(tensor_list, tensor)
3308        >>> tensor_list
3309        [tensor([1, 2], device='cuda:0'), tensor([3, 4], device='cuda:0')] # Rank 0
3310        [tensor([1, 2], device='cuda:1'), tensor([3, 4], device='cuda:1')] # Rank 1
3311
3312        >>> # All tensors below are of torch.cfloat dtype.
3313        >>> # We have 2 process groups, 2 ranks.
3314        >>> tensor_list = [torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2)]
3315        >>> tensor_list
3316        [tensor([0.+0.j, 0.+0.j], device='cuda:0'), tensor([0.+0.j, 0.+0.j], device='cuda:0')] # Rank 0
3317        [tensor([0.+0.j, 0.+0.j], device='cuda:1'), tensor([0.+0.j, 0.+0.j], device='cuda:1')] # Rank 1
3318        >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j)
3319        >>> tensor
3320        tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
3321        tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
3322        >>> dist.all_gather(tensor_list, tensor)
3323        >>> tensor_list
3324        [tensor([1.+1.j, 2.+2.j], device='cuda:0'), tensor([3.+3.j, 4.+4.j], device='cuda:0')] # Rank 0
3325        [tensor([1.+1.j, 2.+2.j], device='cuda:1'), tensor([3.+3.j, 4.+4.j], device='cuda:1')] # Rank 1
3326
3327    """
3328    _check_tensor_list(tensor_list, "tensor_list")
3329    _check_single_tensor(tensor, "tensor")
3330    _ensure_all_tensors_same_dtype(tensor_list, tensor)
3331    if _rank_not_in_group(group):
3332        _warn_not_in_group("all_gather")
3333        return
3334
3335    tensor_list = [
3336        t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list
3337    ]
3338    tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)
3339
3340    group = group or _get_default_group()
3341    work = group.allgather([tensor_list], [tensor])
3342
3343    if async_op:
3344        return work
3345    else:
3346        work.wait()
3347
3348
3349@_exception_logger
3350def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=False):
3351    """
3352    Gather tensors from all ranks and put them in a single output tensor.
3353
3354    This function requires all tensors to be the same size on each process.
3355
3356    Args:
3357        output_tensor (Tensor): Output tensor to accommodate tensor elements
3358            from all ranks. It must be correctly sized to have one of the
3359            following forms:
3360            (i) a concatenation of all the input tensors along the primary
3361            dimension; for definition of "concatenation", see ``torch.cat()``;
3362            (ii) a stack of all the input tensors along the primary dimension;
3363            for definition of "stack", see ``torch.stack()``.
3364            Examples below may better explain the supported output forms.
3365        input_tensor (Tensor): Tensor to be gathered from current rank.
3366            Different from the ``all_gather`` API, the input tensors in this
3367            API must have the same size across all ranks.
3368        group (ProcessGroup, optional): The process group to work on. If None,
3369            the default process group will be used.
3370        async_op (bool, optional): Whether this op should be an async op
3371
3372    Returns:
3373        Async work handle, if async_op is set to True.
3374        None, if not async_op or if not part of the group
3375
3376    Examples:
3377        >>> # xdoctest: +SKIP("need process group init")
3378        >>> # All tensors below are of torch.int64 dtype and on CUDA devices.
3379        >>> # We have two ranks.
3380        >>> device = torch.device(f'cuda:{rank}')
3381        >>> tensor_in = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
3382        >>> tensor_in
3383        tensor([1, 2], device='cuda:0') # Rank 0
3384        tensor([3, 4], device='cuda:1') # Rank 1
3385        >>> # Output in concatenation form
3386        >>> tensor_out = torch.zeros(world_size * 2, dtype=torch.int64, device=device)
3387        >>> dist.all_gather_into_tensor(tensor_out, tensor_in)
3388        >>> tensor_out
3389        tensor([1, 2, 3, 4], device='cuda:0') # Rank 0
3390        tensor([1, 2, 3, 4], device='cuda:1') # Rank 1
3391        >>> # Output in stack form
3392        >>> tensor_out2 = torch.zeros(world_size, 2, dtype=torch.int64, device=device)
3393        >>> dist.all_gather_into_tensor(tensor_out2, tensor_in)
3394        >>> tensor_out2
3395        tensor([[1, 2],
3396                [3, 4]], device='cuda:0') # Rank 0
3397        tensor([[1, 2],
3398                [3, 4]], device='cuda:1') # Rank 1
3399
3400    .. warning::
3401        The Gloo backend does not support this API.
3402
3403    """
3404    _check_single_tensor(input_tensor, "input_tensor")
3405    _check_single_tensor(output_tensor, "output_tensor")
3406    if _rank_not_in_group(group):
3407        _warn_not_in_group("all_gather_into_tensor")
3408        return
3409
3410    output_tensor = (
3411        output_tensor
3412        if not output_tensor.is_complex()
3413        else torch.view_as_real(output_tensor)
3414    )
3415    input_tensor = (
3416        input_tensor
3417        if not input_tensor.is_complex()
3418        else torch.view_as_real(input_tensor)
3419    )
3420
3421    opts = AllgatherOptions()
3422    opts.asyncOp = async_op
3423
3424    group = group or _get_default_group()
3425
3426    if group in _world.pg_coalesce_state.keys():
3427        # We are in coalescing context, do not issue single operation, just append a collective representation
3428        coll = _CollOp(all_gather_into_tensor, input_tensor, output_tensor)
3429        _world.pg_coalesce_state[group].append(coll)
3430        if async_op:
3431            return _IllegalWork()
3432        else:
3433            return None
3434
3435    work = group._allgather_base(output_tensor, input_tensor, opts)
3436
3437    if async_op:
3438        return work
3439    else:
3440        work.wait()
3441
3442
3443@_exception_logger
3444@deprecated(
3445    "`torch.distributed._all_gather_base` is a private function and will be deprecated. "
3446    "Please use `torch.distributed.all_gather_into_tensor` instead.",
3447    category=FutureWarning,
3448)
3449def _all_gather_base(output_tensor, input_tensor, group=None, async_op=False):
3450    """
3451    Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor.
3452
3453    Args:
3454        output_tensor (Tensor): Output tensor. It should contain
3455            correctly-sized tensors to be used for output of the collective.
3456        input_tensor (Tensor): Tensor to be broadcast from current process.
3457        group (ProcessGroup, optional): The process group to work on. If None,
3458            the default process group will be used.
3459        async_op (bool, optional): Whether this op should be an async op
3460
3461    Returns:
3462        Async work handle, if async_op is set to True.
3463        None, if not async_op or if not part of the group
3464
3465    .. warning::
3466        `_all_gather_base` is a private function. Users should use
3467        `all_gather_into_tensor` instead.
3468
3469    """
3470    return all_gather_into_tensor(output_tensor, input_tensor, group, async_op)
3471
3472
3473@_exception_logger
3474@deprecated(
3475    "`torch.distributed.all_gather_coalesced` will be deprecated. If you must use it, "
3476    "please revisit our documentation later at "
3477    "https://pytorch.org/docs/main/distributed.html#collective-functions",
3478    category=FutureWarning,
3479)
3480def all_gather_coalesced(
3481    output_tensor_lists, input_tensor_list, group=None, async_op=False
3482):
3483    """
3484    Gathers input tensors from the whole group in a list in a coalesced manner.
3485
3486    Complex tensors are supported.
3487
3488    Args:
3489        output_tensor_lists (list[list[Tensor]]): Output list. It should contain
3490            correctly-sized tensors to be used for output of the collective.
3491        input_tensor_list (list[Tensor]): Tensors to be broadcast from
3492            current process. At least one tensor has to be non empty.
3493        group (ProcessGroup, optional): The process group to work on. If None,
3494            the default process group will be used.
3495        async_op (bool, optional): Whether this op should be an async op.
3496
3497    Returns:
3498        Async work handle, if async_op is set to True.
3499        None, if not async_op or if not part of the group
3500
3501    Example:
3502        we have 2 process groups, 2 ranks.
3503        rank 0 passes:
3504            input_tensor_list = [[[1, 1], [1, 1]], [2], [3, 3]]
3505            output_tensor_lists =
3506               [[[[-1, -1], [-1, -1]], [-1], [-1, -1]],
3507                [[[-1, -1], [-1, -1]], [-1], [-1, -1]]]
3508        rank 1 passes:
3509            input_tensor_list = [[[3, 3], [3, 3]], [5], [1, 1]]
3510            output_tensor_lists =
3511               [[[[-1, -1], [-1, -1]], [-1], [-1, -1]],
3512                [[[-1, -1], [-1, -1]], [-1], [-1, -1]]]
3513        both rank 0 and 1 get:
3514            output_tensor_lists =
3515               [[[1, 1], [1, 1]], [2], [3, 3]],
3516                [[3, 3], [3, 3]], [5], [1, 1]]].
3517
3518    WARNING: at this time individual shape checking is not implemented across nodes.
3519    For example, if the rank 0 node passes [torch.rand(4), torch.rand(2)] and the
3520    rank 1 node passes [torch.rand(2), torch.rand(2), torch.rand(2)], the
3521    all_gather_coalesced operation will proceed without complaint and return
3522    erroneous outputs. This lack of shape checking results in significant
3523    performance improvements but users of this function should take extra care
3524    to ensure that each node passes in tensors whose shapes match across nodes.
3525    """
3526    # We only check basic compatibility with C++ params here, C++ code will
3527    # do shape and type checking.
3528    if _rank_not_in_group(group):
3529        _warn_not_in_group("all_gather_coalesced")
3530        return
3531    _check_tensor_list(input_tensor_list, "input_tensor_list")
3532    _ensure_all_tensors_same_dtype(input_tensor_list)
3533    if not isinstance(output_tensor_lists, list):
3534        raise TypeError(
3535            "Invalid function argument: output_tensor_lists should be a list"
3536        )
3537    for output_tensor_list in output_tensor_lists:
3538        _check_tensor_list(output_tensor_list, "output_tensor_lists")
3539        _ensure_all_tensors_same_dtype(output_tensor_list)
3540
3541    output_tensor_lists = [
3542        [t if not t.is_complex() else torch.view_as_real(t) for t in l]
3543        for l in output_tensor_lists
3544    ]
3545    input_tensor_list = [
3546        t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list
3547    ]
3548
3549    group = group or _get_default_group()
3550    work = group.allgather_coalesced(output_tensor_lists, input_tensor_list)
3551
3552    if async_op:
3553        return work.get_future()
3554    else:
3555        work.wait()
3556
3557
3558def _validate_output_list_for_rank(my_rank, dst, gather_list):
3559    if dst == my_rank:
3560        if not gather_list:
3561            raise ValueError(
3562                "Argument ``gather_list`` must be specified on destination rank."
3563            )
3564    elif gather_list:
3565        raise ValueError(
3566            "Argument ``gather_list`` must NOT be specified "
3567            "on non-destination ranks."
3568        )
3569
3570
3571@_exception_logger
3572def gather(tensor, gather_list=None, dst=0, group=None, async_op=False):
3573    """
3574    Gathers a list of tensors in a single process.
3575
3576    This function requires all tensors to be the same size on each process.
3577
3578    Args:
3579        tensor (Tensor): Input tensor.
3580        gather_list (list[Tensor], optional): List of appropriately,
3581            same-sized tensors to use for gathered data
3582            (default is None, must be specified on the destination rank)
3583        dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). (default is 0)
3584        group (ProcessGroup, optional): The process group to work on. If None,
3585            the default process group will be used.
3586        async_op (bool, optional): Whether this op should be an async op
3587
3588    Returns:
3589        Async work handle, if async_op is set to True.
3590        None, if not async_op or if not part of the group
3591
3592    """
3593    _check_single_tensor(tensor, "tensor")
3594
3595    # Parameter ``gather_list`` may be left unspecified on non-dst ranks.
3596    if gather_list:
3597        _check_tensor_list(gather_list, "gather_list")
3598    else:
3599        gather_list = []
3600    _ensure_all_tensors_same_dtype(tensor, gather_list)
3601
3602    if _rank_not_in_group(group):
3603        _warn_not_in_group("gather")
3604        return
3605
3606    my_rank = get_rank()
3607    _validate_output_list_for_rank(my_rank, dst, gather_list)
3608    output_tensors = [gather_list] if dst == my_rank else []
3609    input_tensors = [tensor]
3610
3611    opts = GatherOptions()
3612    opts.rootRank = dst
3613
3614    if group is None or group is GroupMember.WORLD:
3615        default_pg = _get_default_group()
3616        work = default_pg.gather(output_tensors, input_tensors, opts)
3617    else:
3618        group_dst_rank = get_group_rank(group, dst)
3619        opts.rootRank = group_dst_rank
3620        work = group.gather(output_tensors, input_tensors, opts)
3621
3622    if async_op:
3623        return work
3624    else:
3625        work.wait()
3626
3627
3628@_exception_logger
3629def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
3630    """
3631    Scatters a list of tensors to all processes in a group.
3632
3633    Each process will receive exactly one tensor and store its data in the
3634    ``tensor`` argument.
3635
3636    Complex tensors are supported.
3637
3638    Args:
3639        tensor (Tensor): Output tensor.
3640        scatter_list (list[Tensor]): List of tensors to scatter (default is
3641            None, must be specified on the source rank)
3642        src (int): Source rank on global process group (regardless of ``group`` argument).
3643            Default is 0
3644        group (ProcessGroup, optional): The process group to work on. If None,
3645            the default process group will be used.
3646        async_op (bool, optional): Whether this op should be an async op
3647
3648    Returns:
3649        Async work handle, if async_op is set to True.
3650        None, if not async_op or if not part of the group
3651
3652    .. note:: Note that all Tensors in scatter_list must have the same size.
3653
3654    Example::
3655        >>> # xdoctest: +SKIP("need process group init")
3656        >>> # Note: Process group initialization omitted on each rank.
3657        >>> import torch.distributed as dist
3658        >>> tensor_size = 2
3659        >>> t_ones = torch.ones(tensor_size)
3660        >>> t_fives = torch.ones(tensor_size) * 5
3661        >>> output_tensor = torch.zeros(tensor_size)
3662        >>> if dist.get_rank() == 0:
3663        >>>     # Assumes world_size of 2.
3664        >>>     # Only tensors, all of which must be the same size.
3665        >>>     scatter_list = [t_ones, t_fives]
3666        >>> else:
3667        >>>     scatter_list = None
3668        >>> dist.scatter(output_tensor, scatter_list, src=0)
3669        >>> # Rank i gets scatter_list[i]. For example, on rank 1:
3670        >>> output_tensor
3671        tensor([5., 5.])
3672
3673    """
3674    _check_single_tensor(tensor, "tensor")
3675
3676    # Parameter ``scatter_list`` may be left unspecified on non-src ranks.
3677    if scatter_list:
3678        _check_tensor_list(scatter_list, "scatter_list")
3679    else:
3680        scatter_list = []
3681    _ensure_all_tensors_same_dtype(tensor, scatter_list)
3682
3683    if _rank_not_in_group(group):
3684        _warn_not_in_group("scatter")
3685        return
3686    scatter_list = [
3687        t if not t.is_complex() else torch.view_as_real(t) for t in scatter_list
3688    ]
3689    tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)
3690
3691    my_rank = get_rank()
3692    if src == my_rank:
3693        if not scatter_list:
3694            raise ValueError(
3695                "Argument ``scatter_list`` must be specified on source rank."
3696            )
3697        input_tensors = [scatter_list]
3698        output_tensors = [tensor]
3699    else:
3700        if scatter_list:
3701            raise ValueError(
3702                "Argument ``scatter_list`` must NOT be specified "
3703                "on non-source ranks."
3704            )
3705        input_tensors = []
3706        output_tensors = [tensor]
3707
3708    opts = ScatterOptions()
3709    opts.rootRank = src
3710    opts.asyncOp = async_op
3711
3712    if group is None or group is GroupMember.WORLD:
3713        default_pg = _get_default_group()
3714        work = default_pg.scatter(output_tensors, input_tensors, opts)
3715    else:
3716        group_src_rank = get_group_rank(group, src)
3717        opts.rootRank = group_src_rank
3718        work = group.scatter(output_tensors, input_tensors, opts)
3719
3720    if async_op:
3721        return work
3722    else:
3723        work.wait()
3724
3725
3726@_exception_logger
3727def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
3728    """
3729    Reduces, then scatters a list of tensors to all processes in a group.
3730
3731    Args:
3732        output (Tensor): Output tensor.
3733        input_list (list[Tensor]): List of tensors to reduce and scatter.
3734        op (optional): One of the values from
3735            ``torch.distributed.ReduceOp``
3736            enum.  Specifies an operation used for element-wise reductions.
3737        group (ProcessGroup, optional): The process group to work on. If None,
3738            the default process group will be used.
3739        async_op (bool, optional): Whether this op should be an async op.
3740
3741    Returns:
3742        Async work handle, if async_op is set to True.
3743        None, if not async_op or if not part of the group.
3744
3745    """
3746    _check_single_tensor(output, "output")
3747    _check_tensor_list(input_list, "input_list")
3748    _ensure_all_tensors_same_dtype(output, input_list)
3749    if _rank_not_in_group(group):
3750        _warn_not_in_group("reduce_scatter")
3751        return
3752
3753    opts = ReduceScatterOptions()
3754    opts.reduceOp = op
3755
3756    group = group or _get_default_group()
3757    work = group.reduce_scatter([output], [input_list], opts)
3758
3759    if async_op:
3760        return work
3761    else:
3762        work.wait()
3763
3764
3765@_exception_logger
3766def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=False):
3767    """
3768    Reduces, then scatters a tensor to all ranks in a group.
3769
3770    Args:
3771        output (Tensor): Output tensor. It should have the same size across all
3772            ranks.
3773        input (Tensor): Input tensor to be reduced and scattered. Its size
3774            should be output tensor size times the world size. The input tensor
3775            can have one of the following shapes:
3776            (i) a concatenation of the output tensors along the primary
3777            dimension, or
3778            (ii) a stack of the output tensors along the primary dimension.
3779            For definition of "concatenation", see ``torch.cat()``.
3780            For definition of "stack", see ``torch.stack()``.
3781        group (ProcessGroup, optional): The process group to work on. If None,
3782            the default process group will be used.
3783        async_op (bool, optional): Whether this op should be an async op.
3784
3785    Returns:
3786        Async work handle, if async_op is set to True.
3787        None, if not async_op or if not part of the group.
3788
3789    Examples:
3790        >>> # xdoctest: +SKIP("need process group init")
3791        >>> # All tensors below are of torch.int64 dtype and on CUDA devices.
3792        >>> # We have two ranks.
3793        >>> device = torch.device(f'cuda:{rank}')
3794        >>> tensor_out = torch.zeros(2, dtype=torch.int64, device=device)
3795        >>> # Input in concatenation form
3796        >>> tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=device)
3797        >>> tensor_in
3798        tensor([0, 1, 2, 3], device='cuda:0') # Rank 0
3799        tensor([0, 1, 2, 3], device='cuda:1') # Rank 1
3800        >>> dist.reduce_scatter_tensor(tensor_out, tensor_in)
3801        >>> tensor_out
3802        tensor([0, 2], device='cuda:0') # Rank 0
3803        tensor([4, 6], device='cuda:1') # Rank 1
3804        >>> # Input in stack form
3805        >>> tensor_in = torch.reshape(tensor_in, (world_size, 2))
3806        >>> tensor_in
3807        tensor([[0, 1],
3808                [2, 3]], device='cuda:0') # Rank 0
3809        tensor([[0, 1],
3810                [2, 3]], device='cuda:1') # Rank 1
3811        >>> dist.reduce_scatter_tensor(tensor_out, tensor_in)
3812        >>> tensor_out
3813        tensor([0, 2], device='cuda:0') # Rank 0
3814        tensor([4, 6], device='cuda:1') # Rank 1
3815
3816    .. warning::
3817        The Gloo backend does not support this API.
3818
3819    """
3820    _check_single_tensor(output, "output")
3821    _check_single_tensor(input, "input")
3822
3823    if _rank_not_in_group(group):
3824        _warn_not_in_group("reduce_scatter_tensor")
3825        return
3826
3827    opts = ReduceScatterOptions()
3828    opts.reduceOp = op
3829    opts.asyncOp = async_op
3830
3831    group = group or _get_default_group()
3832
3833    # Check if we are in coalescing context
3834    # If we are, do not issue single operation, just append a collective representation
3835    if group in _world.pg_coalesce_state.keys():
3836        coll = _CollOp(reduce_scatter_tensor, input, output, op, None)
3837        _world.pg_coalesce_state[group].append(coll)
3838        if async_op:
3839            return _IllegalWork()
3840        else:
3841            return None
3842
3843    work = group._reduce_scatter_base(output, input, opts)
3844
3845    if async_op:
3846        return work
3847    else:
3848        work.wait()
3849
3850
3851@deprecated(
3852    "`torch.distributed._reduce_scatter_base` is a private function and will be deprecated. "
3853    "Please use `torch.distributed.reduce_scatter_tensor` instead.",
3854    category=FutureWarning,
3855)
3856def _reduce_scatter_base(output, input, op=ReduceOp.SUM, group=None, async_op=False):
3857    """
3858    Reduces, then scatters a flattened tensor to all processes in a group.
3859
3860    Args:
3861        output (Tensor): Output tensor.
3862        input (Tensor): Input tensor that is of size output tensor size times world size
3863        group (ProcessGroup, optional): The process group to work on. If None,
3864            the default process group will be used.
3865        async_op (bool, optional): Whether this op should be an async op.
3866
3867    Returns:
3868        Async work handle, if async_op is set to True.
3869        None, if not async_op or if not part of the group.
3870
3871    .. warning::
3872        `_reduce_scatter_base` is a private function. Users should use
3873        `reduce_scatter_tensor` instead.
3874
3875    """
3876    return reduce_scatter_tensor(output, input, op, group, async_op)
3877
3878
3879@_exception_logger
3880def all_to_all_single(
3881    output,
3882    input,
3883    output_split_sizes=None,
3884    input_split_sizes=None,
3885    group=None,
3886    async_op=False,
3887):
3888    """
3889    Split input tensor and then scatter the split list to all processes in a group.
3890
3891    Later the received tensors are concatenated from all the processes in the group
3892    and returned as a single output tensor.
3893
3894    Complex tensors are supported.
3895
3896    Args:
3897        output (Tensor): Gathered concatenated output tensor.
3898        input (Tensor): Input tensor to scatter.
3899        output_split_sizes: (list[Int], optional): Output split sizes for dim 0
3900            if specified None or empty, dim 0 of ``output`` tensor must divide
3901            equally by ``world_size``.
3902        input_split_sizes: (list[Int], optional): Input split sizes for dim 0
3903            if specified None or empty, dim 0 of ``input`` tensor must divide
3904            equally by ``world_size``.
3905        group (ProcessGroup, optional): The process group to work on. If None,
3906            the default process group will be used.
3907        async_op (bool, optional): Whether this op should be an async op.
3908
3909    Returns:
3910        Async work handle, if async_op is set to True.
3911        None, if not async_op or if not part of the group.
3912
3913    .. warning::
3914        `all_to_all_single` is experimental and subject to change.
3915
3916    Examples:
3917        >>> # xdoctest: +SKIP("Undefined rank")
3918        >>> input = torch.arange(4) + rank * 4
3919        >>> input
3920        tensor([0, 1, 2, 3])     # Rank 0
3921        tensor([4, 5, 6, 7])     # Rank 1
3922        tensor([8, 9, 10, 11])   # Rank 2
3923        tensor([12, 13, 14, 15]) # Rank 3
3924        >>> output = torch.empty([4], dtype=torch.int64)
3925        >>> dist.all_to_all_single(output, input)
3926        >>> output
3927        tensor([0, 4, 8, 12])    # Rank 0
3928        tensor([1, 5, 9, 13])    # Rank 1
3929        tensor([2, 6, 10, 14])   # Rank 2
3930        tensor([3, 7, 11, 15])   # Rank 3
3931
3932        >>> # Essentially, it is similar to following operation:
3933        >>> scatter_list = list(input.chunk(world_size))
3934        >>> gather_list  = list(output.chunk(world_size))
3935        >>> for i in range(world_size):
3936        >>>     dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i)
3937
3938        >>> # Another example with uneven split
3939        >>> input
3940        tensor([0, 1, 2, 3, 4, 5])                                       # Rank 0
3941        tensor([10, 11, 12, 13, 14, 15, 16, 17, 18])                     # Rank 1
3942        tensor([20, 21, 22, 23, 24])                                     # Rank 2
3943        tensor([30, 31, 32, 33, 34, 35, 36])                             # Rank 3
3944        >>> input_splits
3945        [2, 2, 1, 1]                                                     # Rank 0
3946        [3, 2, 2, 2]                                                     # Rank 1
3947        [2, 1, 1, 1]                                                     # Rank 2
3948        [2, 2, 2, 1]                                                     # Rank 3
3949        >>> output_splits
3950        [2, 3, 2, 2]                                                     # Rank 0
3951        [2, 2, 1, 2]                                                     # Rank 1
3952        [1, 2, 1, 2]                                                     # Rank 2
3953        [1, 2, 1, 1]                                                     # Rank 3
3954        >>> output = ...
3955        >>> dist.all_to_all_single(output, input, output_splits, input_splits)
3956        >>> output
3957        tensor([ 0,  1, 10, 11, 12, 20, 21, 30, 31])                     # Rank 0
3958        tensor([ 2,  3, 13, 14, 22, 32, 33])                             # Rank 1
3959        tensor([ 4, 15, 16, 23, 34, 35])                                 # Rank 2
3960        tensor([ 5, 17, 18, 24, 36])                                     # Rank 3
3961
3962
3963        >>> # Another example with tensors of torch.cfloat type.
3964        >>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j)
3965        >>> input
3966        tensor([1+1j, 2+2j, 3+3j, 4+4j])                                # Rank 0
3967        tensor([5+5j, 6+6j, 7+7j, 8+8j])                                # Rank 1
3968        tensor([9+9j, 10+10j, 11+11j, 12+12j])                          # Rank 2
3969        tensor([13+13j, 14+14j, 15+15j, 16+16j])                        # Rank 3
3970        >>> output = torch.empty([4], dtype=torch.int64)
3971        >>> dist.all_to_all_single(output, input)
3972        >>> output
3973        tensor([1+1j, 5+5j, 9+9j, 13+13j])                              # Rank 0
3974        tensor([2+2j, 6+6j, 10+10j, 14+14j])                            # Rank 1
3975        tensor([3+3j, 7+7j, 11+11j, 15+15j])                            # Rank 2
3976        tensor([4+4j, 8+8j, 12+12j, 16+16j])                            # Rank 3
3977    """
3978    if _rank_not_in_group(group):
3979        _warn_not_in_group("all_to_all_single")
3980        return
3981
3982    opts = AllToAllOptions()
3983    _check_single_tensor(output, "output")
3984    _check_single_tensor(input, "input")
3985    _ensure_all_tensors_same_dtype(output, input)
3986
3987    if input.is_complex():
3988        input = torch.view_as_real(input)
3989    if output.is_complex():
3990        output = torch.view_as_real(output)
3991
3992    output_split_sizes = [] if output_split_sizes is None else output_split_sizes
3993    input_split_sizes = [] if input_split_sizes is None else input_split_sizes
3994
3995    group = group or _get_default_group()
3996    work = group.alltoall_base(
3997        output, input, output_split_sizes, input_split_sizes, opts
3998    )
3999
4000    if async_op:
4001        return work
4002    else:
4003        work.wait()
4004
4005
4006@_exception_logger
4007def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False):
4008    """
4009    Scatters list of input tensors to all processes in a group and return gathered list of tensors in output list.
4010
4011    Complex tensors are supported.
4012
4013    Args:
4014        output_tensor_list (list[Tensor]): List of tensors to be gathered one
4015            per rank.
4016        input_tensor_list (list[Tensor]): List of tensors to scatter one per rank.
4017        group (ProcessGroup, optional): The process group to work on. If None,
4018            the default process group will be used.
4019        async_op (bool, optional): Whether this op should be an async op.
4020
4021    Returns:
4022        Async work handle, if async_op is set to True.
4023        None, if not async_op or if not part of the group.
4024
4025    .. warning::
4026        `all_to_all` is experimental and subject to change.
4027
4028    Examples:
4029        >>> # xdoctest: +SKIP("Undefined rank")
4030        >>> input = torch.arange(4) + rank * 4
4031        >>> input = list(input.chunk(4))
4032        >>> input
4033        [tensor([0]), tensor([1]), tensor([2]), tensor([3])]     # Rank 0
4034        [tensor([4]), tensor([5]), tensor([6]), tensor([7])]     # Rank 1
4035        [tensor([8]), tensor([9]), tensor([10]), tensor([11])]   # Rank 2
4036        [tensor([12]), tensor([13]), tensor([14]), tensor([15])] # Rank 3
4037        >>> output = list(torch.empty([4], dtype=torch.int64).chunk(4))
4038        >>> dist.all_to_all(output, input)
4039        >>> output
4040        [tensor([0]), tensor([4]), tensor([8]), tensor([12])]    # Rank 0
4041        [tensor([1]), tensor([5]), tensor([9]), tensor([13])]    # Rank 1
4042        [tensor([2]), tensor([6]), tensor([10]), tensor([14])]   # Rank 2
4043        [tensor([3]), tensor([7]), tensor([11]), tensor([15])]   # Rank 3
4044
4045        >>> # Essentially, it is similar to following operation:
4046        >>> scatter_list = input
4047        >>> gather_list  = output
4048        >>> for i in range(world_size):
4049        >>>     dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)
4050
4051        >>> input
4052        tensor([0, 1, 2, 3, 4, 5])                                       # Rank 0
4053        tensor([10, 11, 12, 13, 14, 15, 16, 17, 18])                     # Rank 1
4054        tensor([20, 21, 22, 23, 24])                                     # Rank 2
4055        tensor([30, 31, 32, 33, 34, 35, 36])                             # Rank 3
4056        >>> input_splits
4057        [2, 2, 1, 1]                                                     # Rank 0
4058        [3, 2, 2, 2]                                                     # Rank 1
4059        [2, 1, 1, 1]                                                     # Rank 2
4060        [2, 2, 2, 1]                                                     # Rank 3
4061        >>> output_splits
4062        [2, 3, 2, 2]                                                     # Rank 0
4063        [2, 2, 1, 2]                                                     # Rank 1
4064        [1, 2, 1, 2]                                                     # Rank 2
4065        [1, 2, 1, 1]                                                     # Rank 3
4066        >>> input = list(input.split(input_splits))
4067        >>> input
4068        [tensor([0, 1]), tensor([2, 3]), tensor([4]), tensor([5])]                   # Rank 0
4069        [tensor([10, 11, 12]), tensor([13, 14]), tensor([15, 16]), tensor([17, 18])] # Rank 1
4070        [tensor([20, 21]), tensor([22]), tensor([23]), tensor([24])]                 # Rank 2
4071        [tensor([30, 31]), tensor([32, 33]), tensor([34, 35]), tensor([36])]         # Rank 3
4072        >>> output = ...
4073        >>> dist.all_to_all(output, input)
4074        >>> output
4075        [tensor([0, 1]), tensor([10, 11, 12]), tensor([20, 21]), tensor([30, 31])]   # Rank 0
4076        [tensor([2, 3]), tensor([13, 14]), tensor([22]), tensor([32, 33])]           # Rank 1
4077        [tensor([4]), tensor([15, 16]), tensor([23]), tensor([34, 35])]              # Rank 2
4078        [tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])]                  # Rank 3
4079
4080        >>> # Another example with tensors of torch.cfloat type.
4081        >>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j)
4082        >>> input = list(input.chunk(4))
4083        >>> input
4084        [tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])]            # Rank 0
4085        [tensor([5+5j]), tensor([6+6j]), tensor([7+7j]), tensor([8+8j])]            # Rank 1
4086        [tensor([9+9j]), tensor([10+10j]), tensor([11+11j]), tensor([12+12j])]      # Rank 2
4087        [tensor([13+13j]), tensor([14+14j]), tensor([15+15j]), tensor([16+16j])]    # Rank 3
4088        >>> output = list(torch.empty([4], dtype=torch.int64).chunk(4))
4089        >>> dist.all_to_all(output, input)
4090        >>> output
4091        [tensor([1+1j]), tensor([5+5j]), tensor([9+9j]), tensor([13+13j])]          # Rank 0
4092        [tensor([2+2j]), tensor([6+6j]), tensor([10+10j]), tensor([14+14j])]        # Rank 1
4093        [tensor([3+3j]), tensor([7+7j]), tensor([11+11j]), tensor([15+15j])]        # Rank 2
4094        [tensor([4+4j]), tensor([8+8j]), tensor([12+12j]), tensor([16+16j])]        # Rank 3
4095
4096    """
4097    if _rank_not_in_group(group):
4098        _warn_not_in_group("all_to_all")
4099        return
4100
4101    opts = AllToAllOptions()
4102    _check_tensor_list(output_tensor_list, "output_tensor_list")
4103    _check_tensor_list(input_tensor_list, "input_tensor_list")
4104    _ensure_all_tensors_same_dtype(output_tensor_list, input_tensor_list)
4105
4106    input_tensor_list = [
4107        t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list
4108    ]
4109    output_tensor_list = [
4110        t if not t.is_complex() else torch.view_as_real(t) for t in output_tensor_list
4111    ]
4112
4113    group = group or _get_default_group()
4114    work = group.alltoall(output_tensor_list, input_tensor_list, opts)
4115
4116    if async_op:
4117        return work
4118    else:
4119        work.wait()
4120
4121
4122@_exception_logger
4123def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
4124    """
4125    Synchronize all processes.
4126
4127    This collective blocks processes until the whole group enters this function,
4128    if async_op is False, or if async work handle is called on wait().
4129
4130    Args:
4131        group (ProcessGroup, optional): The process group to work on. If None,
4132            the default process group will be used.
4133        async_op (bool, optional): Whether this op should be an async op
4134        device_ids ([int], optional): List of device/GPU ids.
4135
4136    Returns:
4137        Async work handle, if async_op is set to True.
4138        None, if not async_op or if not part of the group
4139
4140    .. note:: `ProcessGroupNCCL` now relies on stream synchronization instead of
4141              device synchronization to block the CPU. Thus, please do not assume that
4142              `barrier()` would perform a device synchronization.
4143    """
4144    if _rank_not_in_group(group):
4145        _warn_not_in_group("barrier")
4146        return
4147
4148    opts = BarrierOptions()
4149    opts.device = _get_pg_default_device(group)
4150    if device_ids is not None:
4151        if isinstance(device_ids, list):
4152            opts.device_ids = device_ids
4153        else:
4154            raise TypeError(
4155                "Invalid function argument: device_ids type should be List[int]"
4156            )
4157
4158    group = group or _get_default_group()
4159    work = group.barrier(opts=opts)
4160
4161    if async_op:
4162        return work
4163    else:
4164        work.wait()
4165
4166
4167def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=False):
4168    """
4169    Synchronize processes similar to ``torch.distributed.barrier``, but consider a configurable timeout.
4170
4171    It is able to report ranks that did not pass this barrier within the provided timeout.
4172    Specifically, for non-zero ranks, will block until a send/recv is processed from rank 0.
4173    Rank 0 will block until all send /recv from other ranks are processed, and will report
4174    failures for ranks that failed to respond in time. Note that if one rank does not reach the
4175    monitored_barrier (for example due to a hang), all other ranks would fail in monitored_barrier.
4176
4177    This collective will block all processes/ranks in the group, until the
4178    whole group exits the function successfully, making it useful for debugging
4179    and synchronizing. However, it can have a performance impact and should only
4180    be used for debugging or scenarios that require full synchronization points
4181    on the host-side. For debugging purposes, this barrier can be inserted
4182    before the application's collective calls to check if any ranks are
4183    desynchronized.
4184
4185    .. note:: Note that this collective is only supported with the GLOO backend.
4186
4187    Args:
4188        group (ProcessGroup, optional): The process group to work on. If
4189            ``None``, the default process group will be used.
4190        timeout (datetime.timedelta, optional): Timeout for monitored_barrier.
4191            If ``None``, the default process group timeout will be used.
4192        wait_all_ranks (bool, optional): Whether to collect all failed ranks or
4193            not. By default, this is ``False`` and ``monitored_barrier`` on rank 0
4194            will throw on the first failed rank it encounters in order to fail
4195            fast. By setting ``wait_all_ranks=True`` ``monitored_barrier`` will
4196            collect all failed ranks and throw an error containing information
4197            about all failed ranks.
4198
4199    Returns:
4200        ``None``.
4201
4202    Example::
4203        >>> # xdoctest: +SKIP("need process group init")
4204        >>> # Note: Process group initialization omitted on each rank.
4205        >>> import torch.distributed as dist
4206        >>> if dist.get_rank() != 1:
4207        >>>     dist.monitored_barrier() # Raises exception indicating that
4208        >>> # rank 1 did not call into monitored_barrier.
4209        >>> # Example with wait_all_ranks=True
4210        >>> if dist.get_rank() == 0:
4211        >>>     dist.monitored_barrier(wait_all_ranks=True) # Raises exception
4212        >>> # indicating that ranks 1, 2, ... world_size - 1 did not call into
4213        >>> # monitored_barrier.
4214    """
4215    # Need to call rank not in group before using the group, otherwise
4216    # "Invalid process group" error is raised.
4217    if _rank_not_in_group(group):
4218        _warn_not_in_group("monitored_barrier")
4219        return
4220
4221    if get_backend(group) != Backend.GLOO:
4222        raise ValueError("monitored_barrier is only implemented for GLOO backend.")
4223
4224    if timeout is None:
4225        timeout = _get_default_timeout(get_backend(group))
4226    elif isinstance(timeout, float):
4227        # TODO(whc) aparently some existing test case for monitored_barrier passes in a timeout in float format?
4228        warnings.warn(
4229            "Please specify timeout arg as a timedelta. "
4230            f"Converting current value of {timeout} assuming it represents seconds",
4231        )
4232        timeout = timedelta(seconds=timeout)
4233
4234    _check_valid_timeout(timeout)
4235
4236    group_to_use = _get_default_group() if group is None else group
4237    return group_to_use.monitored_barrier(timeout, wait_all_ranks=wait_all_ranks)
4238
4239
4240def _create_process_group_wrapper(
4241    wrapped_pg: torch._C._distributed_c10d.Backend,
4242    store_prefix: str,
4243    store: Store,
4244    rank: int,
4245    world_size: int,
4246    timeout: timedelta = default_pg_timeout,
4247):
4248    assert _GLOO_AVAILABLE, "ProcessGroupWrapper unsupported without GLOO backend."
4249
4250    # (whc) this appears to be just for the gloo backend? if so, `default_pg_timeout` is appropriate...
4251
4252    # Create a separate prefix store for the helper process group.
4253    prefix = f"{PG_WRAPPER_STORE_PREFIX}:{store_prefix}"
4254    store = PrefixStore(prefix, store)
4255    helper_pg = ProcessGroupGloo(store, rank, world_size, timeout=timeout)
4256    # Wrap the underlying pg with ProcessGroupWrapper.
4257    wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg)
4258    return wrapped_pg
4259
4260
4261# helper function for deterministically hashing a list of ranks
4262def _hash_ranks(ranks: List[int]):
4263    return hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()
4264
4265
4266# Takes a list of ranks and computes an integer color
4267def _process_group_color(ranks: List[int]) -> int:
4268    # Convert our hash to an int, but avoid negative numbers by shifting a bit.
4269    return int(_hash_ranks(ranks), 16) % (sys.maxsize >> 1)
4270
4271
4272def _process_group_name(ranks, use_hashed_name):
4273    global _world
4274    if use_hashed_name:
4275        pg_name = _hash_ranks(ranks)
4276        while pg_name in _world.pg_names.values():
4277            pg_name = hashlib.sha1(bytes(pg_name + "_", "utf-8")).hexdigest()
4278    else:
4279        pg_name = str(_world.group_count)
4280        _world.group_count += 1
4281    return pg_name
4282
4283
4284def _get_backend_from_str(backend: Optional[str] = None) -> Backend:
4285    # Default to the same backend as the global process group
4286    #  if backend is not specified.
4287    if not backend:
4288        backend = get_backend(_get_default_group())
4289    return Backend(backend)
4290
4291
4292def _is_safe_to_split() -> bool:
4293    """
4294    Checks if it is safe to split the any process group in the world.
4295    This is only safe if the default pg has a bound device id, otherwise
4296    users must be aware that a pg is only splittable after the first collective is
4297    issued.
4298    """
4299    return False if _get_default_group().bound_device_id is None else True
4300
4301
4302@_time_logger
4303def split_group(
4304    parent_pg: Optional[ProcessGroup] = None,
4305    split_ranks: Optional[list] = None,
4306    timeout: Optional[timedelta] = None,
4307    pg_options: Optional[Any] = None,
4308    group_desc: Optional[str] = None,
4309) -> Optional[ProcessGroup]:
4310    """
4311    Create a new process group splitted from the given parent process group.
4312
4313    warning:: This is an experimental API and only the ``NCCL`` backend supports this API.
4314    Other backends will raise an error.
4315    Users of this API must gurantee that all ranks in the parent group enter this API call,
4316    and the split of the sub groups is the same accross all ranks in the parent group.
4317
4318    Args:
4319        parent_pg (ProcessGroup, optional): The parent process group. If None,
4320            the default process group will be used. Users need to gurantee that
4321            the parent group is fully initialized (e.g, communicators are initialized)
4322        split_ranks (list[list[int]]): the split ranks, which is a list of list of ranks.
4323            Users need to make sure the validity of the split ranks such that one
4324            split (represented by one inner list of ints) does not overlap with any other split.
4325            Note that the ranks in each split is the group rank (instead of global rank)
4326            in the parent pg. For example, if the parent group has 4 ranks, and split_ranks can be
4327            [[0, 1], [2, 3]]. Note [[0,1]] is also a valid split, in which case ranks 2, 3 would
4328            return a non-group member.
4329        timeout (timedelta, optional): see `init_process_group` for details and default value.
4330        pg_options (ProcessGroupOptions, optional): only ProcessGroupNCCLOptions is supported now.
4331            specifying what additional options need to be passed in during
4332            the construction of specific process groups. i.e.``is_high_priority_stream``
4333            can be specified so that process group can pick up high priority cuda streams.
4334            For other availble options to config nccl,
4335            See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
4336        group_desc (str, optional): a string to describe the process group.
4337
4338    Returns:
4339        ProcessGroup if the current rank is within one split/subgroup given by split_ranks,
4340        or None if the current rank is not part of any split_ranks`.
4341
4342    """
4343    # check inputs
4344    if split_ranks is None:
4345        raise ValueError("split_ranks cannot be None")
4346
4347    global _world
4348    default_pg = _get_default_group()
4349    device_id = default_pg.bound_device_id
4350    if not device_id:
4351        raise RuntimeError(
4352            "No device associated with the default pg, not safe to split any process groups"
4353        )
4354    default_backend, default_store = _world.pg_map[default_pg]
4355    global_rank = default_pg.rank()
4356    global_world_size = default_pg.size()
4357
4358    if not parent_pg:
4359        parent_pg = default_pg
4360    if parent_pg not in _world.pg_group_ranks:
4361        raise ValueError(f"Group {parent_pg} is not registered")
4362
4363    parent_global_to_group_ranks = _world.pg_group_ranks[parent_pg]
4364    parent_group_to_global_ranks = {
4365        group_rank: global_rank
4366        for global_rank, group_rank in parent_global_to_group_ranks.items()
4367    }
4368
4369    if global_rank not in parent_global_to_group_ranks:
4370        raise ValueError(
4371            f"Global rank {global_rank} is not part of the parent group {parent_pg}"
4372        )
4373
4374    parent_group_rank = parent_global_to_group_ranks[global_rank]
4375    parent_backend = parent_pg._get_backend(torch.device("cuda"))
4376
4377    # if the parent backend does not support splitting, raise error
4378    # currently this API only support NCCL backend
4379    if (
4380        not parent_backend
4381        or not parent_backend.supports_splitting
4382        or not isinstance(parent_backend, ProcessGroupNCCL)
4383    ):
4384        raise RuntimeError(
4385            "No backend for the parent process group or its backend does not support splitting"
4386        )
4387
4388    # set the group_desc before the color or no_cloor split
4389    group_desc = (
4390        f"{parent_pg.group_desc}:split:{parent_backend.comm_split_count()}"
4391        if group_desc is None
4392        else group_desc
4393    )
4394
4395    parent_backend_str, _ = _world.pg_map[parent_pg]
4396    # same type of backend as the parent process group
4397    backend = Backend(parent_backend_str)
4398    backend_config = BackendConfig(backend)
4399
4400    if pg_options is not None:
4401        assert isinstance(
4402            pg_options, ProcessGroupNCCL.Options
4403        ), "Expected pg_options argument to be of type ProcessGroupNCCL.Options"
4404    else:
4405        # default pg_options same as the parent process group
4406        pg_options = parent_backend.options
4407
4408    # this timeout defaulting/validation is used for all the new_groups/new_subgroups variants,
4409    # which may just pass their timeout value (or None)
4410    if timeout is None:
4411        timeout = _get_default_timeout(backend)
4412    _check_valid_timeout(timeout)
4413
4414    # find my group of ranks and my group local rank in split_ranks
4415    my_group = None
4416    group_rank = -1
4417
4418    for split_group in split_ranks:
4419        if len(split_group) == 0:
4420            raise ValueError("the split group cannot be empty")
4421        if len(split_group) > global_world_size:
4422            raise ValueError(
4423                "the split group's size should be less or equal to the world_size set by init_process_group"
4424            )
4425        if len(split_group) != len(set(split_group)):
4426            raise ValueError("the split group cannot have duplicate ranks")
4427        split_group = sorted(split_group)
4428        if parent_group_rank in split_group:
4429            my_group = split_group
4430            group_rank = split_group.index(parent_group_rank)
4431            break
4432    # if my rank does not belong to any sub group,
4433    # no_color split should be called
4434    if my_group is None or group_rank == -1:
4435        parent_backend.perform_nocolor_split(device_id)
4436        return None
4437
4438    group_name = _process_group_name(my_group, use_hashed_name=False)
4439    global_ranks_in_my_group = [parent_group_to_global_ranks[rank] for rank in my_group]
4440
4441    prefix_store = PrefixStore(f"{group_name}/", default_store)
4442    base_pg_options = ProcessGroup.Options(backend=str(backend))
4443    base_pg_options._timeout = timeout
4444    pg: ProcessGroup = ProcessGroup(
4445        prefix_store, group_rank, len(my_group), base_pg_options
4446    )
4447    pg.bound_device_id = device_id
4448
4449    pg_options._timeout = timeout
4450    pg_options.split_from = parent_backend
4451    pg_options.split_color = _process_group_color(my_group)
4452    pg_options.global_ranks_in_group = global_ranks_in_my_group
4453    pg_options.group_name = group_name
4454    backend_class = ProcessGroupNCCL(
4455        prefix_store, group_rank, len(my_group), pg_options
4456    )
4457    backend_type = ProcessGroup.BackendType.NCCL
4458    backend_class._set_sequence_number_for_group()
4459
4460    pg._register_backend(torch.device("cuda"), backend_type, backend_class)
4461
4462    # set group_name and group_desc to backend
4463    assert group_name is not None
4464    assert group_desc is not None
4465    pg._set_group_name(group_name)
4466    pg._set_group_desc(group_desc)
4467
4468    # always eagerly initialize the backend in split_group
4469    eager_backend = pg._get_backend(device_id)
4470    eager_backend.eager_connect_single_device(device_id)
4471
4472    # update global state
4473    _world.pg_map[pg] = (backend, prefix_store)
4474    _world.pg_names[pg] = group_name
4475    _register_process_group(group_name, pg)
4476    _world.pg_backend_config[pg] = str(backend_config)
4477    pg_tag = f"ptd:{group_name}"
4478    _world.tags_to_pg.setdefault(pg_tag, []).append(pg)
4479    _world.pg_to_tag[pg] = pg_tag
4480
4481    # Create the global rank to group rank mapping
4482    _world.pg_group_ranks[pg] = {
4483        global_rank: group_rank
4484        for group_rank, global_rank in enumerate(global_ranks_in_my_group)
4485    }
4486
4487    return pg
4488
4489
4490@_time_logger
4491def new_group(
4492    ranks=None,
4493    timeout=None,
4494    backend=None,
4495    pg_options=None,
4496    use_local_synchronization=False,
4497    group_desc=None,
4498):
4499    """
4500    Create a new distributed group.
4501
4502    This function requires that all processes in the main group (i.e. all
4503    processes that are part of the distributed job) enter this function, even
4504    if they are not going to be members of the group. Additionally, groups
4505    should be created in the same order in all processes.
4506
4507    .. warning::
4508        Safe concurrent usage:
4509        When using multiple process groups with the ``NCCL`` backend, the user
4510        must ensure a globally consistent execution order of collectives across
4511        ranks.
4512
4513        If multiple threads within a process issue collectives, explicit
4514        synchronization is necessary to ensure consistent ordering.
4515
4516        When using async variants of torch.distributed communication APIs,
4517        a work object is returned and the communication kernel is
4518        enqueued on a separate CUDA stream, allowing overlap of communication
4519        and computation. Once one or more async ops have been issued on one process
4520        group, they must be synchronized with other cuda streams by calling `work.wait()`
4521        before using another process group.
4522
4523        See `Using multiple NCCL communicators concurrently <https://docs.nvid
4524        ia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#using
4525        -multiple-nccl-communicators-concurrently>`_ for more details.
4526
4527    Args:
4528        ranks (list[int]): List of ranks of group members. If ``None``, will be
4529            set to all ranks. Default is ``None``.
4530        timeout (timedelta, optional): see `init_process_group` for details and default value.
4531        backend (str or Backend, optional): The backend to use. Depending on
4532            build-time configurations, valid values are ``gloo`` and ``nccl``.
4533            By default uses the same backend as the global group. This field
4534            should be given as a lowercase string (e.g., ``"gloo"``), which can
4535            also be accessed via :class:`Backend` attributes (e.g.,
4536            ``Backend.GLOO``). If ``None`` is passed in, the backend
4537            corresponding to the default process group will be used. Default is
4538            ``None``.
4539        pg_options (ProcessGroupOptions, optional): process group options
4540            specifying what additional options need to be passed in during
4541            the construction of specific process groups. i.e. for the ``nccl``
4542            backend, ``is_high_priority_stream`` can be specified so that
4543            process group can pick up high priority cuda streams. For other availble options to config nccl,
4544            See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
4545        use_local_synchronization (bool, optional): perform a group-local
4546            barrier at the end of the process group creation. This is different
4547            in that non-member ranks don't need to call into API and don't
4548            join the barrier.
4549        group_desc (str, optional): a string to describe the process group.
4550
4551    Returns:
4552        A handle of distributed group that can be given to collective calls or
4553        GroupMember.NON_GROUP_MEMBER if the rank is not part of ``ranks``.
4554
4555    N.B. use_local_synchronization doesn't work with MPI.
4556
4557    N.B. While use_local_synchronization=True can be significantly faster with larger
4558    clusters and small process groups, care must be taken since it changes cluster behavior
4559    as non-member ranks don't join the group barrier().
4560
4561    N.B. use_local_synchronization=True can lead to deadlocks when each rank creates
4562    multiple overlaping process groups. To avoid that, make sure all ranks follow the
4563    same global creation order.
4564    """
4565    return _new_group_with_tag(
4566        ranks,
4567        timeout,
4568        backend,
4569        pg_options,
4570        None,
4571        use_local_synchronization=use_local_synchronization,
4572        group_desc=group_desc,
4573    )
4574
4575
4576def _new_group_with_tag(
4577    ranks=None,
4578    timeout=None,
4579    backend=None,
4580    pg_options=None,
4581    pg_tag=None,
4582    use_local_synchronization=False,
4583    group_desc=None,
4584):
4585    """
4586    Variant of ``new_group`` that exposes tag creation.
4587
4588    :: N.B. The mechanism is experimental and tied to the functional collectives effort, see
4589    ``torch.distributed._functional_collectives`` for reference on how to use it.
4590    """
4591    global _world
4592
4593    default_pg = _get_default_group()
4594    device_id = default_pg.bound_device_id
4595    default_backend, default_store = _world.pg_map[default_pg]
4596    global_rank = default_pg.rank()
4597    global_world_size = default_pg.size()
4598
4599    # Default to the same backend as the global process group
4600    # if the backend is not specified.
4601    if not backend:
4602        backend = default_backend
4603    backend = Backend(backend)
4604
4605    # this timeout defaulting/validation is used for all the new_groups/new_subgroups variants,
4606    # which may just pass their timeout value (or None)
4607    if timeout is None:
4608        timeout = _get_default_timeout(backend)
4609    _check_valid_timeout(timeout)
4610
4611    if use_local_synchronization:
4612        # MPI backend doesn't have have a way for us to perform a partial sync
4613        if backend == Backend.MPI:
4614            raise ValueError(
4615                "MPI backend doesn't support use_local_synchronization=True"
4616            )
4617        if ranks is not None and get_rank() not in ranks:
4618            return None
4619
4620    # checks the input ranks
4621    if ranks is not None:
4622        ranks = sorted(ranks)
4623        group_world_size = len(ranks)
4624        if group_world_size > global_world_size:
4625            raise ValueError(
4626                "the new group's world size should be less or "
4627                "equal to the world size set by "
4628                "init_process_group"
4629            )
4630        # check ranks' sanity
4631        for rank in ranks:
4632            if rank < 0 or rank >= global_world_size:
4633                raise ValueError(
4634                    "The new group's rank should be within "
4635                    "the world_size set by init_process_group"
4636                )
4637        if global_rank in ranks:
4638            group_rank = ranks.index(global_rank)
4639        else:
4640            group_rank = None
4641    else:
4642        ranks = list(range(global_world_size))
4643        group_world_size = global_world_size
4644        group_rank = global_rank
4645
4646    group_name = _process_group_name(ranks, use_hashed_name=use_local_synchronization)
4647
4648    pg, pg_store = _new_process_group_helper(
4649        group_world_size,
4650        group_rank,
4651        ranks,
4652        backend,
4653        default_store,
4654        group_name,
4655        pg_options=pg_options,
4656        timeout=timeout,
4657        pg_tag=pg_tag,
4658        device_id=device_id,
4659        group_desc=group_desc,
4660    )
4661
4662    # Create the global rank to group rank mapping
4663    _world.pg_group_ranks[pg] = {
4664        global_rank: group_rank for group_rank, global_rank in enumerate(ranks)
4665    }
4666
4667    if _is_barrier_after_init() == 1:
4668        # barrier at the end to ensure that once we return from this method, all
4669        # process groups including global variables (if any) are updated
4670        # correctly on all ranks.
4671        # Update 04/2023: for large-scale runs, this barrier (esp. store-based
4672        # barrier) may be costly and/or unscalable. Also, in a lot of cases,
4673        # these barriers may be unnecessary, as proven by a green CI after
4674        # removal. An environment variable `TORCH_DIST_INIT_BARRIER` has been
4675        # added which enables this barrier only when set to 1.
4676        logger.info(
4677            "Performing barrier after ProcessGroup initialization since "
4678            "TORCH_DIST_INIT_BARRIER = 1"
4679        )
4680        if backend == Backend.MPI:
4681            # MPI doesn't have store.
4682            barrier()
4683        else:
4684            barrier_store = pg_store if use_local_synchronization else default_store
4685            world_size = len(ranks) if use_local_synchronization else get_world_size()
4686            # Use store based barrier here since barrier() used a bunch of
4687            # default devices and messes up NCCL internal state.
4688            _store_based_barrier(
4689                global_rank, barrier_store, group_name, world_size, timeout
4690            )
4691
4692    return pg
4693
4694
4695def new_subgroups(
4696    group_size=None,
4697    group=None,
4698    timeout=None,
4699    backend=None,
4700    pg_options=None,
4701    group_desc=None,
4702):
4703    """
4704    Create subgroups of equal size.
4705
4706    By default, it creates intra-machine subgroups,
4707    where each of which contains all the ranks of a machine, based on the assumption
4708    that each machine has the same number of devices.
4709
4710    This is a convenience API that calls ``new_group`` to generate multiple subgroups.
4711    It requires that all processes in the main group (i.e. all
4712    processes that are part of the distributed job) enter this function, even
4713    if they are not going to be members of the group.
4714
4715    .. warning::
4716        If ``group_size`` is passed in, the world size must be divisible by ``group_size``.
4717        If no ``group_size`` is passed in, it believe that you are creating a group based
4718        on CUDA and determining the group size by number of CUDA devices, and if not all
4719        the machines have the same number of devices, the subgroup division will be
4720        different across nodes and can cause unexpected behaviors. Therefore, if you are
4721        creating a subgroup that does not depend on CUDA (such as Gloo on CPU), please
4722        pass in ``group_size`` correctly.
4723
4724    .. warning::
4725        See warning `Safe concurrent usage` for `new_group` API for important details about
4726        using multiple process groups concurrently in a safe manner.
4727
4728    Args:
4729        group_size (int, optional): The size of each subgroup. If ``None``,
4730            the default subgroup size is equal to the number of devices on each machine,
4731            based on the assumption that each machine has exactly the same
4732            number of devices. Default is ``None``.
4733        timeout (timedelta, optional): see `init_process_group` for details and default value.
4734        backend (str or Backend, optional): The backend to use. Depending on
4735            build-time configurations, valid values are ``gloo`` and ``nccl``.
4736            By default uses the same backend as the global group. This field
4737            should be given as a lowercase string (e.g., ``"gloo"``), which can
4738            also be accessed via :class:`Backend` attributes (e.g.,
4739            ``Backend.GLOO``). If ``None`` is passed in, the backend
4740            corresponding to the default process group will be used. Default is
4741            ``None``.
4742        pg_options (ProcessGroupOptions, optional): process group options
4743            specifying what additional options need to be passed in during
4744            the construction of specific process groups. i.e. for the ``nccl``
4745            backend, ``is_high_priority_stream`` can be specified so that
4746            process group can pick up high priority cuda streams.
4747        group_desc (str, optional): A string describing the group. Each subgroup will
4748            inherit its group_desc
4749
4750    Returns:
4751        The subgroup containing the current rank, and all the subgroups used for cleanup.
4752
4753    Examples:
4754        >>> # Create intra-machine subgroups.
4755        >>> # xdoctest: +SKIP("need process group init")
4756        >>> cur_subgroup, subgroups = dist.new_subgroups()
4757        >>> # Allreduce within the machine.
4758        >>> rank = dist.get_rank()
4759        >>> tensor = torch.ones(1, device=rank) * rank
4760        >>> dist.all_reduce(tensor, group=cur_subgroup)
4761        >>> tensor
4762        tensor([28])  # Assume 8 CUDA devices per machine.  28 is sum(range(8)).
4763        >>> # Cleanup.
4764        >>> for subgroup in subgroups:
4765        >>>     dist.destroy_process_group(subgroup)
4766    """
4767    if group_size is None:
4768        if not torch.cuda.is_available():
4769            raise ValueError(
4770                "Default group size only takes effect when CUDA is available."
4771                "If your subgroup using a backend that does not depend on CUDA,"
4772                "please pass in 'group_size' correctly."
4773            )
4774        group_size = torch.cuda.device_count()
4775    if group_size <= 0:
4776        raise ValueError(f"The arg 'group_size' ({group_size}) must be positive")
4777
4778    world_size = get_world_size()
4779    if world_size < group_size:
4780        raise ValueError(
4781            f"The arg 'group_size' ({group_size}) must not exceed the world size ({world_size})"
4782        )
4783    if world_size % group_size != 0:
4784        raise ValueError("The world size must be divisible by 'group_size'")
4785
4786    subgroups = []
4787    cur_subgroup = None
4788
4789    for subgroup_id in range(world_size // group_size):
4790        start_rank = subgroup_id * group_size
4791        end_rank = start_rank + group_size
4792        ranks_in_subgroup = list(range(start_rank, end_rank))
4793        subgroup = new_group(
4794            ranks=ranks_in_subgroup,
4795            timeout=timeout,
4796            backend=backend,
4797            pg_options=pg_options,
4798            group_desc=group_desc,
4799        )
4800        subgroups.append(subgroup)
4801
4802        rank = get_rank()
4803        if rank in ranks_in_subgroup:
4804            cur_subgroup = subgroup
4805            logger.info("Rank %s is assigned to subgroup %s", rank, ranks_in_subgroup)
4806
4807    return cur_subgroup, subgroups
4808
4809
4810def new_subgroups_by_enumeration(
4811    ranks_per_subgroup_list,
4812    timeout=None,
4813    backend=None,
4814    pg_options=None,
4815    group_desc=None,
4816):
4817    """
4818    Create subgroups by dividing the global world.
4819
4820    The division is specified by a nested list of ranks. The subgroups cannot have
4821    overlap, and some ranks may not have to be in any subgroup.
4822
4823    This is a convenience API that calls ``new_group`` to generate multiple subgroups.
4824    It requires that all processes in the main group (i.e. all
4825    processes that are part of the distributed job) enter this function, even
4826    if they are not going to be members of the group.
4827
4828    .. warning::
4829        See warning `Safe concurrent usage` for `new_group` API for important details about
4830        using multiple process groups concurrently in a safe manner.
4831
4832    Args:
4833        ranks_per_subgroup_list (list[list[int]]): A nested list of ranks of
4834            group members.
4835        timeout (timedelta, optional): see `init_process_group` for details and default value.
4836        backend (str or Backend, optional): The backend to use. Depending on
4837             build-time configurations, valid values are ``gloo`` and ``nccl``.
4838             By default uses the same backend as the global group. This field
4839             should be given as a lowercase string (e.g., ``"gloo"``), which can
4840             also be accessed via :class:`Backend` attributes (e.g.,
4841             ``Backend.GLOO``). If ``None`` is passed in, the backend
4842             corresponding to the default process group will be used. Default is
4843             ``None``.
4844        pg_options (ProcessGroupOptions, optional): process group options
4845            specifying what additional options need to be passed in during
4846            the construction of specific process groups. i.e. for the ``nccl``
4847            backend, ``is_high_priority_stream`` can be specified so that
4848            process group can pick up high priority cuda streams.
4849        group_desc (str, optional): A string describing the group. Each subgroup will
4850            inherit its group_desc.
4851
4852    Returns:
4853        The subgroup containing the current rank, and all the subgroups used for cleanup.
4854
4855    Examples:
4856        >>> # Create two subgroups, where each has 2 processes.
4857        >>> # xdoctest: +SKIP("need process group init")
4858        >>> cur_subgroup, subgroups = dist.new_subgroups(ranks=[[0, 2], [1, 3]])
4859        >>> rank = dist.get_rank()
4860        >>> tensor = torch.ones(1, device=rank) * rank
4861        >>> dist.all_reduce(tensor, group=cur_subgroup)
4862        >>> tensor
4863        tensor([2])     # Subgroup 0: ranks 0 and 2
4864        tensor([4])     # Subgroup 1: ranks 1 and 3
4865    """
4866    if ranks_per_subgroup_list is None or len(ranks_per_subgroup_list) == 0:
4867        raise ValueError("The arg 'ranks_per_subgroup_list' cannot be empty")
4868
4869    subgroups = []
4870    cur_subgroup = None
4871    # Create a mapping from rank to subgroup to check if there is any subgroup overlap.
4872    rank_to_ranks_dict = {}  # type: ignore[var-annotated]
4873    for ranks in ranks_per_subgroup_list:
4874        subgroup = new_group(
4875            ranks=ranks,
4876            timeout=timeout,
4877            backend=backend,
4878            pg_options=pg_options,
4879            group_desc=group_desc,
4880        )
4881        subgroups.append(subgroup)
4882        my_rank = get_rank()
4883        for rank in ranks:
4884            if rank in rank_to_ranks_dict:
4885                raise ValueError(
4886                    f"Rank {rank} has appeared in both subgroup {rank_to_ranks_dict[rank]} and {ranks}"
4887                )
4888            rank_to_ranks_dict[rank] = ranks
4889            if my_rank == rank:
4890                cur_subgroup = subgroup
4891                logger.info("Rank %s is assigned to subgroup %s", rank, ranks)
4892
4893    return cur_subgroup, subgroups
4894
4895
4896def _find_pg_by_ranks_and_tag(tag: str, ranks: List[int]) -> Optional[ProcessGroup]:
4897    if len(tag) > 0 and not tag.startswith("ptd:") and not tag.startswith("user:"):
4898        tag = f"user:{tag}"
4899
4900    for group in _world.tags_to_pg.get(tag, []):
4901        if group.size() != len(ranks):
4902            continue
4903
4904        group_ranks = get_process_group_ranks(group)
4905        good = all(r in group_ranks for r in ranks)
4906        if good:
4907            return group
4908    return None
4909
4910
4911def _find_or_create_pg_by_ranks_and_tag(
4912    tag: str, ranks: List[int], stride: int
4913) -> ProcessGroup:
4914    assert (
4915        len(ranks) % stride == 0
4916    ), f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
4917
4918    my_rank = get_rank()
4919    my_ranks = None
4920
4921    if stride == len(ranks):
4922        my_ranks = ranks.copy()
4923        assert my_rank in my_ranks, "rankset doesn't include the current node"
4924    else:
4925        for i in range(0, len(ranks), stride):
4926            rank_set = ranks[i : i + stride]
4927            if my_rank in rank_set:
4928                my_ranks = rank_set
4929        assert my_ranks is not None, "rankset doesn't include the current node"
4930
4931    my_ranks = sorted(my_ranks)
4932
4933    pg = _find_pg_by_ranks_and_tag(tag, my_ranks)
4934    if pg is not None:
4935        return pg
4936    if tag == "":
4937        raise ValueError("Cannot automatically create PG with empty tag")
4938    # TODO copy settings and timeout from default PG
4939    return _new_group_with_tag(my_ranks, pg_tag=tag)
4940
4941
4942def _get_group_tag(pg: ProcessGroup) -> str:
4943    """Return the tag associated with ``pg``."""
4944    tag = _world.pg_to_tag[pg]
4945    if tag.startswith("user:"):
4946        tag = tag[5:]
4947    return tag
4948
4949
4950def _get_process_group_name(pg: ProcessGroup) -> str:
4951    return _world.pg_names.get(pg, "None")
4952
4953
4954def _get_process_group_store(pg: ProcessGroup) -> Store:
4955    return _world.pg_map[pg][1]
4956
4957
4958# This ops are not friendly to TorchDynamo. So, we decide to disallow these ops
4959# in FX graph, allowing them to run them on eager, with torch.compile.
4960dynamo_unsupported_distributed_c10d_ops = [
4961    recv,
4962    all_gather_object,
4963    all_gather_coalesced,
4964    all_to_all_single,
4965    all_reduce,
4966    gather_object,
4967    all_to_all,
4968    all_reduce_coalesced,
4969    gather,
4970    send_object_list,
4971    recv_object_list,
4972    broadcast_object_list,
4973    barrier,
4974    scatter,
4975    scatter_object_list,
4976    reduce,
4977    all_gather,
4978    reduce_scatter,
4979    all_gather_into_tensor,
4980    broadcast,
4981    reduce_scatter_tensor,
4982    send,
4983]
4984