# mypy: allow-untyped-defs
import logging
import os
import threading
import warnings
from datetime import timedelta
from typing import Generator, Tuple
from urllib.parse import urlparse

import torch
import torch.distributed as dist


__all__ = ["is_available"]


logger = logging.getLogger(__name__)


_init_counter = 0
_init_counter_lock = threading.Lock()


def is_available() -> bool:
    return hasattr(torch._C, "_rpc_init")


if is_available() and not torch._C._rpc_init():
    raise RuntimeError("Failed to initialize torch.distributed.rpc")


if is_available():
    import numbers

    import torch.distributed.autograd as dist_autograd
    from torch._C._distributed_c10d import Store
    from torch._C._distributed_rpc import (  # noqa: F401
        _cleanup_python_rpc_handler,
        _DEFAULT_INIT_METHOD,
        _DEFAULT_NUM_WORKER_THREADS,
        _DEFAULT_RPC_TIMEOUT_SEC,
        _delete_all_user_and_unforked_owner_rrefs,
        _destroy_rref_context,
        _disable_jit_rref_pickle,
        _disable_server_process_global_profiler,
        _enable_jit_rref_pickle,
        _enable_server_process_global_profiler,
        _get_current_rpc_agent,
        _invoke_remote_builtin,
        _invoke_remote_python_udf,
        _invoke_remote_torchscript,
        _invoke_rpc_builtin,
        _invoke_rpc_python_udf,
        _invoke_rpc_torchscript,
        _is_current_rpc_agent_set,
        _reset_current_rpc_agent,
        _rref_context_get_debug_info,
        _set_and_start_rpc_agent,
        _set_profiler_node_id,
        _set_rpc_timeout,
        _TensorPipeRpcBackendOptionsBase,
        _UNSET_RPC_TIMEOUT,
        enable_gil_profiling,
        get_rpc_timeout,
        PyRRef,
        RemoteProfilerManager,
        RpcAgent,
        RpcBackendOptions,
        TensorPipeAgent,
        WorkerInfo,
    )

    from . import api, backend_registry, functions
    from .api import *  # noqa: F401,F403
    from .backend_registry import BackendType
    from .options import TensorPipeRpcBackendOptions  # noqa: F401
    from .server_process_global_profiler import _server_process_global_profile

    rendezvous_iterator: Generator[Tuple[Store, int, int], None, None]

    __all__ += ["init_rpc", "BackendType", "TensorPipeRpcBackendOptions"]
    __all__ = __all__ + api.__all__ + backend_registry.__all__  # noqa: PLE0605

    def init_rpc(
        name,
        backend=None,
        rank=-1,
        world_size=None,
        rpc_backend_options=None,
    ):
        r"""
        Initializes RPC primitives such as the local RPC agent
        and distributed autograd, which immediately makes the current
        process ready to send and receive RPCs.

        Args:
            name (str): a globally unique name of this node. (e.g.,
                ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``)
                Name can only contain numbers, letters, underscores, colons,
                and/or dashes, and must be shorter than 128 characters.
            backend (BackendType, optional): The type of RPC backend
                implementation. The only supported value is
                ``BackendType.TENSORPIPE`` (the default).
                See :ref:`rpc-backends` for more information.
            rank (int): a globally unique id/rank of this node.
            world_size (int, optional): The number of workers in the group.
                May be left unset for a dynamic group.
            rpc_backend_options (RpcBackendOptions, optional): The options
                passed to the RpcAgent constructor. It must be an agent-specific
                subclass of :class:`~torch.distributed.rpc.RpcBackendOptions`
                and contain agent-specific initialization configurations.
                By default, for all agents, it sets the default timeout to 60
                seconds and performs the rendezvous with an underlying process
                group initialized using ``init_method = "env://"``,
                meaning that environment variables ``MASTER_ADDR`` and
                ``MASTER_PORT`` need to be set properly. See
                :ref:`rpc-backends` for more information and a list of
                available options.
        """
        torch._C._log_api_usage_once("torch.distributed.init_rpc")
        if backend is not None and not isinstance(
            backend, backend_registry.BackendType
        ):
            raise TypeError("Argument backend must be a member of BackendType")

        if rpc_backend_options is not None and not isinstance(
            rpc_backend_options, RpcBackendOptions
        ):
            raise TypeError(
                "Argument rpc_backend_options must be an instance of RpcBackendOptions"
            )

        # Try to detect the backend from the options
        if backend is None and rpc_backend_options is not None:
            for candidate_backend in BackendType:
                if isinstance(
                    rpc_backend_options,
                    type(
                        backend_registry.construct_rpc_backend_options(
                            candidate_backend
                        )
                    ),
                ):
                    backend = candidate_backend
                    break
            else:
                raise TypeError(
                    f"Could not infer backend for options {rpc_backend_options}"
                )
            # Ignore type error because mypy doesn't handle dynamically generated type objects (#4865)
            if backend != BackendType.TENSORPIPE:  # type: ignore[attr-defined]
                logger.warning(
                    "RPC was initialized with no explicit backend but with options "  # type: ignore[attr-defined]
                    "corresponding to %(backend)s, hence that backend will be used "
                    "instead of the default BackendType.TENSORPIPE. To silence this "
                    "warning pass `backend=%(backend)s` explicitly.",
                    {"backend": backend},
                )

        if backend is None:
            backend = BackendType.TENSORPIPE  # type: ignore[attr-defined]

        if rpc_backend_options is None:
            # Default-construct a set of RPC backend options.
            rpc_backend_options = backend_registry.construct_rpc_backend_options(
                backend
            )

        # Create the store, which performs the rendezvous for a static RPC group.
        if not world_size:
            # If world_size is set neither in the constructor nor in environment
            # variables, create the store for the dynamic group setting.
            store = dist._create_store_from_options(rpc_backend_options, rank)
        else:
            # This rendezvous state sometimes is destroyed before all processes
            # finish handshaking. To avoid that issue, we make it global to
            # keep it alive.
            global rendezvous_iterator
            rendezvous_iterator = dist.rendezvous(
                rpc_backend_options.init_method, rank=rank, world_size=world_size
            )
            store, _, _ = next(rendezvous_iterator)
            # Use the same timeout as RPC.
            store.set_timeout(timedelta(seconds=rpc_backend_options.rpc_timeout))

        # Use a PrefixStore to distinguish multiple invocations.
        with _init_counter_lock:
            global _init_counter
            store = dist.PrefixStore(f"rpc_prefix_{_init_counter}", store)
            _init_counter += 1

        # Initialize autograd before RPC since _init_rpc_backend guarantees all
        # processes sync via the store. If we initialized autograd after RPC,
        # there could be a race where some nodes have initialized autograd
        # and others have not. As a result, a node calling
        # torch.distributed.autograd.backward() would run into errors since
        # other nodes might not have been initialized yet.
        dist_autograd._init(rank)

        _set_profiler_node_id(rank)
        # Initialize RPC.
        _init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options)

    def _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options):
        # Map each argument's value to its expected type; the argument values
        # themselves serve as the dict keys here.
        type_mapping = {
            backend: backend_registry.BackendType,
            store: dist.Store,
            name: str,
            rank: numbers.Integral,
            # world_size can be None for a dynamic group
            world_size: (numbers.Integral, type(None)),
            rpc_backend_options: RpcBackendOptions,
        }
        for arg, arg_type in type_mapping.items():
            if not isinstance(arg, arg_type):  # type: ignore[arg-type]
                raise RuntimeError(
                    f"Argument {arg} must be of type {arg_type} but got type {type(arg)}"
                )

    def _init_rpc_backend(
        backend=BackendType.TENSORPIPE,  # type: ignore[attr-defined]
        store=None,
        name=None,
        rank=-1,
        world_size=None,
        rpc_backend_options=None,
    ):
        _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options)

        if _is_current_rpc_agent_set():
            raise RuntimeError("RPC is already initialized")

        # Initialize RPC.
        rpc_agent = backend_registry.init_backend(
            backend,
            store=store,
            name=name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

        api._init_rpc_states(rpc_agent)

    @api._require_initialized
    def _get_debug_info():
        info = _rref_context_get_debug_info()
        info.update(api._get_current_rpc_agent().get_debug_info())
        info.update(dist_autograd._get_debug_info())
        return info
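
# A minimal usage sketch of init_rpc (illustrative only, kept as a comment so
# nothing executes on import): two processes rendezvous via the default
# "env://" init method and exchange a single RPC. It relies only on names
# re-exported by this module (`init_rpc`, and `rpc_sync`/`shutdown` via
# `from .api import *`); the worker names and port below are arbitrary.
#
#   import os
#
#   import torch
#   import torch.distributed.rpc as rpc
#   import torch.multiprocessing as mp
#
#   def run(rank):
#       os.environ["MASTER_ADDR"] = "localhost"
#       os.environ["MASTER_PORT"] = "29500"
#       rpc.init_rpc(f"worker{rank}", rank=rank, world_size=2)
#       if rank == 0:
#           # Runs torch.add on worker1 and blocks until the result arrives.
#           result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 1))
#       rpc.shutdown()  # blocks until all outstanding RPC work drains
#
#   if __name__ == "__main__":
#       mp.spawn(run, nprocs=2)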