# mypy: allow-untyped-defs
import logging
import os
import threading
import warnings
from datetime import timedelta
from typing import Generator, Tuple
from urllib.parse import urlparse

import torch
import torch.distributed as dist


__all__ = ["is_available"]


logger = logging.getLogger(__name__)


_init_counter = 0
_init_counter_lock = threading.Lock()


def is_available() -> bool:
    return hasattr(torch._C, "_rpc_init")


if is_available() and not torch._C._rpc_init():
    raise RuntimeError("Failed to initialize torch.distributed.rpc")


if is_available():
    import numbers

    import torch.distributed.autograd as dist_autograd
    from torch._C._distributed_c10d import Store
    from torch._C._distributed_rpc import (  # noqa: F401
        _cleanup_python_rpc_handler,
        _DEFAULT_INIT_METHOD,
        _DEFAULT_NUM_WORKER_THREADS,
        _DEFAULT_RPC_TIMEOUT_SEC,
        _delete_all_user_and_unforked_owner_rrefs,
        _destroy_rref_context,
        _disable_jit_rref_pickle,
        _disable_server_process_global_profiler,
        _enable_jit_rref_pickle,
        _enable_server_process_global_profiler,
        _get_current_rpc_agent,
        _invoke_remote_builtin,
        _invoke_remote_python_udf,
        _invoke_remote_torchscript,
        _invoke_rpc_builtin,
        _invoke_rpc_python_udf,
        _invoke_rpc_torchscript,
        _is_current_rpc_agent_set,
        _reset_current_rpc_agent,
        _rref_context_get_debug_info,
        _set_and_start_rpc_agent,
        _set_profiler_node_id,
        _set_rpc_timeout,
        _TensorPipeRpcBackendOptionsBase,
        _UNSET_RPC_TIMEOUT,
        enable_gil_profiling,
        get_rpc_timeout,
        PyRRef,
        RemoteProfilerManager,
        RpcAgent,
        RpcBackendOptions,
        TensorPipeAgent,
        WorkerInfo,
    )

    from . import api, backend_registry, functions
    from .api import *  # noqa: F401,F403
    from .backend_registry import BackendType
    from .options import TensorPipeRpcBackendOptions  # noqa: F401
    from .server_process_global_profiler import _server_process_global_profile

    rendezvous_iterator: Generator[Tuple[Store, int, int], None, None]

    __all__ += ["init_rpc", "BackendType", "TensorPipeRpcBackendOptions"]
    __all__ = __all__ + api.__all__ + backend_registry.__all__  # noqa: PLE0605

    def init_rpc(
        name,
        backend=None,
        rank=-1,
        world_size=None,
        rpc_backend_options=None,
    ):
        r"""
        Initializes RPC primitives such as the local RPC agent
        and distributed autograd, which immediately makes the current
        process ready to send and receive RPCs.

        Args:
            name (str): a globally unique name of this node. (e.g.,
                ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``)
                The name may only contain letters, digits, underscores, colons,
                and/or dashes, and must be shorter than 128 characters.
            backend (BackendType, optional): The type of RPC backend
                implementation. The only supported value is
                ``BackendType.TENSORPIPE`` (the default).
                See :ref:`rpc-backends` for more information.
            rank (int): a globally unique id/rank of this node.
            world_size (int): The number of workers in the group.
            rpc_backend_options (RpcBackendOptions, optional): The options
                passed to the RpcAgent constructor. It must be an agent-specific
                subclass of :class:`~torch.distributed.rpc.RpcBackendOptions`
                and contain agent-specific initialization configurations. By
                default, for all agents, it sets the default timeout to 60
                seconds and performs the rendezvous with an underlying process
                group initialized using ``init_method = "env://"``,
                meaning that environment variables ``MASTER_ADDR`` and
                ``MASTER_PORT`` need to be set properly. See
                :ref:`rpc-backends` for more information and to find which
                options are available.
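
        Example::
            A minimal usage sketch, assuming ``MASTER_ADDR`` and ``MASTER_PORT``
            are set in the environment and that two processes run this code
            with ranks 0 and 1.

            >>> # On worker 0:
            >>> import torch
            >>> import torch.distributed.rpc as rpc
            >>> rpc.init_rpc("worker0", rank=0, world_size=2)
            >>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
            >>> rpc.shutdown()

            >>> # On worker 1:
            >>> import torch.distributed.rpc as rpc
            >>> rpc.init_rpc("worker1", rank=1, world_size=2)
            >>> rpc.shutdown()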
        """
        torch._C._log_api_usage_once("torch.distributed.init_rpc")
        if backend is not None and not isinstance(
            backend, backend_registry.BackendType
        ):
            raise TypeError("Argument backend must be a member of BackendType")

        if rpc_backend_options is not None and not isinstance(
            rpc_backend_options, RpcBackendOptions
        ):
            raise TypeError(
                "Argument rpc_backend_options must be an instance of RpcBackendOptions"
            )

        # Try to detect the backend from the options
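        # (e.g. a TensorPipeRpcBackendOptions instance maps back to
        # BackendType.TENSORPIPE, since it is an instance of the options type
        # that construct_rpc_backend_options returns for that backend.)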
        if backend is None and rpc_backend_options is not None:
            for candidate_backend in BackendType:
                if isinstance(
                    rpc_backend_options,
                    type(
                        backend_registry.construct_rpc_backend_options(
                            candidate_backend
                        )
                    ),
                ):
                    backend = candidate_backend
                    break
            else:
                raise TypeError(
                    f"Could not infer backend for options {rpc_backend_options}"
                )
            # Ignore type error because mypy doesn't handle dynamically generated type objects (#4865)
            if backend != BackendType.TENSORPIPE:  # type: ignore[attr-defined]
                logger.warning(
                    "RPC was initialized with no explicit backend but with options "  # type: ignore[attr-defined]
                    "corresponding to %(backend)s, hence that backend will be used "
                    "instead of the default BackendType.TENSORPIPE. To silence this "
                    "warning pass `backend=%(backend)s` explicitly.",
                    {"backend": backend},
                )

        if backend is None:
            backend = BackendType.TENSORPIPE  # type: ignore[attr-defined]

        if rpc_backend_options is None:
            # default construct a set of RPC backend options.
            rpc_backend_options = backend_registry.construct_rpc_backend_options(
                backend
            )

        # Create the store and perform rendezvous for a static RPC group.
        if not world_size:
            # If world_size is not set at construction time and also not set via
            # environment variables, the store will be created for the dynamic group setting.
            store = dist._create_store_from_options(rpc_backend_options, rank)
        else:
            # This rendezvous state is sometimes destroyed before all processes
            # finish handshaking. To avoid that issue, we make it global to
            # keep it alive.
            global rendezvous_iterator
            rendezvous_iterator = dist.rendezvous(
                rpc_backend_options.init_method, rank=rank, world_size=world_size
            )
            store, _, _ = next(rendezvous_iterator)
        # Use the same timeout as RPC.
        store.set_timeout(timedelta(seconds=rpc_backend_options.rpc_timeout))

        # Use a PrefixStore to distinguish multiple invocations.
        with _init_counter_lock:
            global _init_counter
            store = dist.PrefixStore(str(f"rpc_prefix_{_init_counter}"), store)
            _init_counter += 1

        # Initialize autograd before RPC since _init_rpc_backend guarantees all
        # processes sync via the store. If we initialize autograd after RPC,
        # there could be a race where some nodes might have initialized autograd
        # and others might not have. As a result, a node calling
        # torch.distributed.autograd.backward() would run into errors since
        # other nodes might not have been initialized.
        dist_autograd._init(rank)

        _set_profiler_node_id(rank)
        # Initialize RPC.
        _init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options)

    def _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options):
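        """Check that each argument has its expected type; raise
        ``RuntimeError`` on the first mismatch."""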
        type_mapping = {
            backend: backend_registry.BackendType,
            store: dist.Store,
            name: str,
            rank: numbers.Integral,
            # world_size can be None for a dynamic group
            world_size: (numbers.Integral, type(None)),
            rpc_backend_options: RpcBackendOptions,
        }
        for arg, arg_type in type_mapping.items():
            if not isinstance(arg, arg_type):  # type: ignore[arg-type]
                raise RuntimeError(
                    f"Argument {arg} must be of type {arg_type} but got type {type(arg)}"
                )

    def _init_rpc_backend(
        backend=BackendType.TENSORPIPE,  # type: ignore[attr-defined]
        store=None,
        name=None,
        rank=-1,
        world_size=None,
        rpc_backend_options=None,
    ):
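        """Validate the arguments, construct the RPC agent via the backend
        registry, and register it as the current agent. Raises ``RuntimeError``
        if RPC has already been initialized."""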
        _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options)

        if _is_current_rpc_agent_set():
            raise RuntimeError("RPC is already initialized")

        # Initialize RPC.
        rpc_agent = backend_registry.init_backend(
            backend,
            store=store,
            name=name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

        api._init_rpc_states(rpc_agent)

    @api._require_initialized
    def _get_debug_info():
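        """Collect debug info from the RRef context, the current RPC agent,
        and distributed autograd into a single dict."""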
        info = _rref_context_get_debug_info()
        info.update(api._get_current_rpc_agent().get_debug_info())
        info.update(dist_autograd._get_debug_info())
        return info
250