xref: /aosp_15_r20/external/pytorch/torch/distributed/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import logging
import pdb
import sys
import traceback
import typing

import torch


log = logging.getLogger(__name__)


def is_available() -> bool:
    """
    Return ``True`` if the distributed package is available.

    Otherwise, ``torch.distributed`` does not expose any other APIs. Currently,
    ``torch.distributed`` is available on Linux, MacOS and Windows. Set
    ``USE_DISTRIBUTED=1`` to enable it when building PyTorch from source.
    Currently, the default value is ``USE_DISTRIBUTED=1`` for Linux and Windows,
    and ``USE_DISTRIBUTED=0`` for MacOS.
    """
    return hasattr(torch._C, "_c10d_init")

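# Illustrative usage (a sketch, not executed as part of this module): downstream
# code typically guards distributed-only paths on this check, e.g.
#
#     import torch.distributed as dist
#
#     if dist.is_available() and dist.is_initialized():
#         rank = dist.get_rank()
#     else:
#         rank = 0  # single-process fallback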

if is_available() and not torch._C._c10d_init():
    raise RuntimeError("Failed to initialize torch.distributed")

# Custom Runtime Errors thrown from the distributed package
DistError = torch._C._DistError
DistBackendError = torch._C._DistBackendError
DistNetworkError = torch._C._DistNetworkError
DistStoreError = torch._C._DistStoreError

if is_available():
    from torch._C._distributed_c10d import (
        _broadcast_coalesced,
        _compute_bucket_assignment_by_size,
        _ControlCollectives,
        _DEFAULT_FIRST_BUCKET_BYTES,
        _make_nccl_premul_sum,
        _register_builtin_comm_hook,
        _register_comm_hook,
        _StoreCollectives,
        _test_python_store,
        _verify_params_across_processes,
        Backend as _Backend,
        BuiltinCommHookType,
        DebugLevel,
        FileStore,
        get_debug_level,
        GradBucket,
        Logger,
        PrefixStore,
        ProcessGroup as ProcessGroup,
        Reducer,
        set_debug_level,
        set_debug_level_from_env,
        Store,
        TCPStore,
        Work as _Work,
    )

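    # Illustrative sketch of the re-exported key/value store API (the address,
    # port, and world size below are arbitrary and not part of this module):
    #
    #     from datetime import timedelta
    #     # on the server process
    #     server = TCPStore("127.0.0.1", 29500, 2, True, timeout=timedelta(seconds=30))
    #     # on each client process
    #     client = TCPStore("127.0.0.1", 29500, 2, False)
    #     server.set("key", "value")
    #     assert client.get("key") == b"value"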
    class _DistributedPdb(pdb.Pdb):
        """
        Supports using PDB from inside a multiprocessing child process.

        Usage:
        _DistributedPdb().set_trace()
        """

        def interaction(self, *args, **kwargs):
            _stdin = sys.stdin
            try:
                sys.stdin = open("/dev/stdin")
                pdb.Pdb.interaction(self, *args, **kwargs)
            finally:
                sys.stdin = _stdin

    _breakpoint_cache: typing.Dict[int, typing.Any] = {}

    def breakpoint(rank: int = 0, skip: int = 0):
        """
        Set a breakpoint, but only on a single rank.  All other ranks will wait for you to be
        done with the breakpoint before continuing.

        Args:
            rank (int): Which rank to break on.  Default: ``0``
            skip (int): Skip the first ``skip`` calls to this breakpoint. Default: ``0``.
        """
        if skip > 0:
            key = hash(str(traceback.format_exc()))
            counter = _breakpoint_cache.get(key, 0) + 1
            _breakpoint_cache[key] = counter
            if counter <= skip:
                log.warning("Skip the breakpoint, counter=%d", counter)
                return

        if get_rank() == rank:
            pdb = _DistributedPdb()
            pdb.message(
                "\n!!! ATTENTION !!!\n\n"
                f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n"
            )
            pdb.set_trace()
        # If Meta/Python keys are in the TLS, we want to make sure that we ignore them
        # and hit the (default) CPU/CUDA implementation of barrier.
        meta_in_tls = torch._C._meta_in_tls_dispatch_include()
        guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
        torch._C._set_meta_in_tls_dispatch_include(False)
        try:
            barrier()
        finally:
            torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)
            del guard

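    # Illustrative usage (a sketch, not executed here): in a script launched with
    # torchrun, pause only rank 0 while the remaining ranks block in the barrier:
    #
    #     import torch.distributed as dist
    #
    #     dist.init_process_group("gloo")
    #     ...
    #     dist.breakpoint(rank=0)  # rank 0 drops into pdb; other ranks wait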
    if sys.platform != "win32":
        from torch._C._distributed_c10d import HashStore

    from .device_mesh import DeviceMesh, init_device_mesh

    # Variables prefixed with underscore are not auto imported
    # See the comment in `distributed_c10d.py` above `_backend` on why we expose
    # this.
    from .distributed_c10d import *  # noqa: F403
    from .distributed_c10d import (
        _all_gather_base,
        _coalescing_manager,
        _CoalescingManager,
        _create_process_group_wrapper,
        _get_process_group_name,
        _rank_not_in_group,
        _reduce_scatter_base,
        get_node_local_rank,
    )
    from .remote_device import _remote_device
    from .rendezvous import (
        _create_store_from_options,
        register_rendezvous_handler,
        rendezvous,
    )

    set_debug_level_from_env()

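    # Note: set_debug_level_from_env() reads the TORCH_DISTRIBUTED_DEBUG
    # environment variable (OFF, INFO, or DETAIL). Illustrative invocation,
    # with the script name as a placeholder:
    #
    #     TORCH_DISTRIBUTED_DEBUG=DETAIL torchrun --nproc_per_node=2 train.py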
else:
    # This stub is sufficient to get
    #   python test/test_public_bindings.py -k test_correct_module_names
    # working even when USE_DISTRIBUTED=0.  Feel free to add more
    # stubs as necessary.
    # We cannot define stubs directly because they confuse pyre

    class _ProcessGroupStub:
        pass

    sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub  # type: ignore[attr-defined]
158