# mypy: allow-untyped-defs
import logging
import pdb
import sys
import traceback
import typing

import torch


log = logging.getLogger(__name__)


def is_available() -> bool:
    """
    Return ``True`` if the distributed package is available.

    Otherwise, ``torch.distributed`` does not expose any other APIs. Currently,
    ``torch.distributed`` is available on Linux, MacOS and Windows. Set
    ``USE_DISTRIBUTED=1`` to enable it when building PyTorch from source.
    Currently, the default value is ``USE_DISTRIBUTED=1`` for Linux and Windows,
    ``USE_DISTRIBUTED=0`` for MacOS.
    """
    return hasattr(torch._C, "_c10d_init")


if is_available() and not torch._C._c10d_init():
    raise RuntimeError("Failed to initialize torch.distributed")

# Custom Runtime Errors thrown from the distributed package
DistError = torch._C._DistError
DistBackendError = torch._C._DistBackendError
DistNetworkError = torch._C._DistNetworkError
DistStoreError = torch._C._DistStoreError

if is_available():
    from torch._C._distributed_c10d import (
        _broadcast_coalesced,
        _compute_bucket_assignment_by_size,
        _ControlCollectives,
        _DEFAULT_FIRST_BUCKET_BYTES,
        _make_nccl_premul_sum,
        _register_builtin_comm_hook,
        _register_comm_hook,
        _StoreCollectives,
        _test_python_store,
        _verify_params_across_processes,
        Backend as _Backend,
        BuiltinCommHookType,
        DebugLevel,
        FileStore,
        get_debug_level,
        GradBucket,
        Logger,
        PrefixStore,
        ProcessGroup as ProcessGroup,
        Reducer,
        set_debug_level,
        set_debug_level_from_env,
        Store,
        TCPStore,
        Work as _Work,
    )

    class _DistributedPdb(pdb.Pdb):
        """
        Supports using PDB from inside a multiprocessing child process.

        Usage:
            _DistributedPdb().set_trace()
        """

        def interaction(self, *args, **kwargs):
            _stdin = sys.stdin
            try:
                sys.stdin = open("/dev/stdin")
                pdb.Pdb.interaction(self, *args, **kwargs)
            finally:
                sys.stdin = _stdin

    _breakpoint_cache: typing.Dict[int, typing.Any] = {}

    def breakpoint(rank: int = 0, skip: int = 0):
        """
        Set a breakpoint, but only on a single rank. All other ranks will wait for you to be
        done with the breakpoint before continuing.

        Args:
            rank (int): Which rank to break on. Default: ``0``
            skip (int): Skip the first ``skip`` calls to this breakpoint. Default: ``0``.
        """
        if skip > 0:
            key = hash(str(traceback.format_exc()))
            counter = _breakpoint_cache.get(key, 0) + 1
            _breakpoint_cache[key] = counter
            if counter <= skip:
                log.warning("Skip the breakpoint, counter=%d", counter)
                return

        if get_rank() == rank:
            pdb = _DistributedPdb()
            pdb.message(
                "\n!!! ATTENTION !!!\n\n"
                f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n"
            )
            pdb.set_trace()
        # If Meta/Python keys are in the TLS, we want to make sure that we ignore them
        # and hit the (default) CPU/CUDA implementation of barrier.
        meta_in_tls = torch._C._meta_in_tls_dispatch_include()
        guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
        torch._C._set_meta_in_tls_dispatch_include(False)
        try:
            barrier()
        finally:
            torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)
            del guard

    if sys.platform != "win32":
        from torch._C._distributed_c10d import HashStore

    from .device_mesh import DeviceMesh, init_device_mesh

    # Variables prefixed with underscore are not auto imported
    # See the comment in `distributed_c10d.py` above `_backend` on why we expose
    # this.
    from .distributed_c10d import *  # noqa: F403
    from .distributed_c10d import (
        _all_gather_base,
        _coalescing_manager,
        _CoalescingManager,
        _create_process_group_wrapper,
        _get_process_group_name,
        _rank_not_in_group,
        _reduce_scatter_base,
        get_node_local_rank,
    )
    from .remote_device import _remote_device
    from .rendezvous import (
        _create_store_from_options,
        register_rendezvous_handler,
        rendezvous,
    )

    set_debug_level_from_env()

else:
    # This stub is sufficient to get
    #   python test/test_public_bindings.py -k test_correct_module_names
    # working even when USE_DISTRIBUTED=0. Feel free to add more
    # stubs as necessary.
    # We cannot define stubs directly because they confuse pyre.

    class _ProcessGroupStub:
        pass

    sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub  # type: ignore[attr-defined]
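
# A minimal usage sketch (an illustrative assumption, not code exercised by this
# module): with a two-process job launched via ``torchrun --nproc-per-node=2
# script.py``, the APIs re-exported above are typically driven like this:
#
#   import torch.distributed as dist
#
#   if dist.is_available():
#       dist.init_process_group(backend="gloo")  # "nccl" is common on GPU hosts
#       dist.breakpoint(rank=0)  # rank 0 drops into pdb; other ranks wait in barrier()
#       dist.destroy_process_group()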