xref: /aosp_15_r20/external/pytorch/torch/distributed/rendezvous.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2try:
3    from urllib.parse import urlparse, urlunparse
4except ImportError as e:
5    raise ImportError(
6        "urllib cannot be found, urlparse from python2 is no longer supported."
7    ) from e
8
9import numbers
10import os
11import sys
12from datetime import timedelta
13from typing import Callable, Dict, Iterator, Optional, Tuple
14
15from torch.distributed import FileStore, PrefixStore, Store, TCPStore
16
17from .constants import default_pg_timeout
18
19
# Registry that maps a URL scheme (for example "tcp") to the handler invoked
# for URLs using that scheme. Entries are added by
# register_rendezvous_handler() and looked up by _rendezvous_helper().
_rendezvous_handlers: Dict[str, Callable[..., Iterator[Tuple[Store, int, int]]]] = {}

# Public names of this module.
__all__ = ["register_rendezvous_handler", "rendezvous"]
23
24
def register_rendezvous_handler(scheme, handler):
    """
    Associate a rendezvous handler with a URL scheme.

    Rendezvous is the step in which the participating processes discover
    one another and exchange the information they need to communicate; it
    has to complete before any collective algorithm can run. Its result is
    a triplet: a shared key/value store, the rank of the process, and the
    total number of participating processes.

    When none of the bundled rendezvous methods suits your execution
    environment, register a handler of your own under a unique scheme name
    and pass a URL with that scheme to the `rendezvous()` function.

    Args:
        scheme (str): URL scheme that identifies the handler.
        handler (function): Generator function that yields the triplet.
            It is invoked when `rendezvous()` is called with a URL whose
            scheme matches ``scheme``.

    Raises:
        RuntimeError: If a handler is already registered for ``scheme``.
    """
    already_registered = scheme in _rendezvous_handlers
    if already_registered:
        raise RuntimeError(f"Rendezvous handler for {scheme}:// already registered")
    # Item assignment mutates the module-level registry in place, so no
    # ``global`` declaration is needed here.
    _rendezvous_handlers[scheme] = handler
53
54
# Query will have format "rank=0&world_size=1" and is
# converted into {"rank": "0", "world_size": "1"} (values remain strings).
57def _query_to_dict(query: str) -> Dict[str, str]:
58    return {
59        pair[0]: pair[1]
60        for pair in (pair.split("=") for pair in filter(None, query.split("&")))
61    }
62
63
64def _get_use_libuv_from_query_dict(query_dict: Dict[str, str]) -> bool:
65    # libuv is the default backend for TCPStore. To enable the non-libuv backend,
66    # user can explicitly specify ``use_libuv=0`` in the URL parameter.
67    return query_dict.get("use_libuv", os.environ.get("USE_LIBUV", "1")) == "1"
68
69
def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwargs):
    """
    Fold ``rank``/``world_size`` into ``url`` and dispatch to its handler.

    When ``world_size_opt`` is ``None`` the world size starts out as ``-1``;
    for the ``env`` scheme, ``rank`` and the world size are then taken from
    the ``RANK`` and ``WORLD_SIZE`` environment variables when those are set.
    Any rank other than ``-1`` and any world size other than ``-1`` (or any
    world size at all when ``world_size_opt`` is ``None``) is appended to the
    URL query before the handler registered for the URL scheme is called.

    Args:
        url: Rendezvous URL; its scheme selects the handler.
        rank: Rank of this process, or ``-1`` if not specified.
        world_size_opt: Number of processes, ``-1`` if not specified, or
            ``None`` to request the environment-driven behavior above.
        **kwargs: Forwarded unchanged to the handler.

    Returns:
        Whatever the selected handler returns when called with the
        (possibly rewritten) URL.

    Raises:
        AssertionError: If the URL query already carries ``rank`` or
            ``world_size`` while values for them need to be added.
        RuntimeError: If no handler is registered for the URL scheme.
    """
    parsed = urlparse(url)
    world_size_unset = world_size_opt is None
    if world_size_unset:
        world_size = -1
        if parsed.scheme == "env":
            rank = int(os.environ.get("RANK", rank))
            # If the world_size env variable is not present then it is a dynamic group
            world_size = int(os.environ.get("WORLD_SIZE", world_size))
    else:
        world_size = world_size_opt

    add_rank = rank != -1
    add_world_size = world_size != -1 or world_size_unset
    if add_rank or add_world_size:
        params = _query_to_dict(parsed.query)
        assert not (
            "rank" in params or "world_size" in params
        ), f"The url: {url} has node-specific arguments(rank, world_size) already."
        if add_rank:
            params["rank"] = str(rank)
        if add_world_size:
            params["world_size"] = str(world_size)
        new_query = "&".join(f"{key}={value}" for key, value in params.items())
        parsed = parsed._replace(query=new_query)
        url = urlunparse(parsed)

    scheme = parsed.scheme
    if scheme not in _rendezvous_handlers:
        raise RuntimeError(f"No rendezvous handler for {scheme}://")
    handler = _rendezvous_handlers[scheme]
    return handler(url, **kwargs)
97
98
def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
    """
    Validate the arguments and start rendezvous for ``url``.

    Args:
        url: Rendezvous URL whose scheme selects the registered handler.
        rank: Rank of this process; ``-1`` (the default) means unspecified.
        world_size: Number of participating processes; ``-1`` (the default)
            means unspecified.
        **kwargs: Forwarded to the rendezvous handler.

    Returns:
        The result of ``_rendezvous_helper`` for the validated arguments.

    Raises:
        RuntimeError: If ``url`` is neither ``str`` nor ``bytes``, or if
            ``rank`` or ``world_size`` is not an integral number.
    """
    url_is_text = isinstance(url, (str, bytes))
    if not url_is_text:
        raise RuntimeError(f"`url` must be a string. {type(url)}: {url}")

    rank_is_integral = isinstance(rank, numbers.Integral)
    if not rank_is_integral:
        raise RuntimeError(f"`rank` must be an integer. {rank}")

    world_size_is_integral = isinstance(world_size, numbers.Integral)
    if not world_size_is_integral:
        raise RuntimeError(f"`world_size` must be an integer. {world_size}")

    return _rendezvous_helper(url, rank, world_size, **kwargs)
110
111
def _create_store_from_options(backend_options, rank):
    """
    Return the store obtained by rendezvousing on ``backend_options.init_method``.

    The rendezvous is started with ``world_size_opt=None`` and only the first
    triplet produced by the handler is consumed; its rank and world size are
    discarded.

    Args:
        backend_options: Object exposing an ``init_method`` URL attribute.
        rank: Rank of this process.

    Returns:
        The store element of the first yielded triplet.
    """
    triplets = _rendezvous_helper(backend_options.init_method, rank, None)
    store, _unused_rank, _unused_world_size = next(triplets)
    return store
115
116
117def _rendezvous_error(msg):
118    return ValueError("Error initializing torch.distributed using " + msg)
119
120
def _file_rendezvous_handler(url: str, **kwargs):
    """
    Rendezvous handler for ``file://`` URLs, backed by a ``FileStore``.

    The URL must carry a path plus ``rank`` and ``world_size`` query
    parameters. This is a generator: it yields a single
    ``(store, rank, world_size)`` triplet, and resuming it afterwards raises
    because re-rendezvous is not supported for this method.

    Args:
        url: ``file://`` rendezvous URL.
        **kwargs: Accepted for interface compatibility; not used here.

    Yields:
        Tuple of the ``FileStore``, the rank, and the world size.

    Raises:
        ValueError: If the path, ``rank`` or ``world_size`` is missing.
        RuntimeError: If the generator is resumed after the first yield.
    """

    def _fail(detail):
        return _rendezvous_error("file:// rendezvous: " + detail)

    parsed = urlparse(url)
    if sys.platform == "win32":
        from urllib.request import url2pathname

        # On Windows the path is rebuilt from netloc + path and converted
        # from URL form to a local path.
        path = url2pathname(parsed.netloc + parsed.path)
        if path:
            # Normalizing an empty string produces ".", which is not expected.
            path = os.path.normpath(path)
    else:
        path = parsed.path

    if not path:
        raise _fail("path missing")

    params = _query_to_dict(parsed.query)
    required = (
        ("rank", "rank parameter missing"),
        ("world_size", "world size parameter missing"),
    )
    for key, complaint in required:
        if key not in params:
            raise _fail(complaint)

    rank = int(params["rank"])
    world_size = int(params["world_size"])
    yield (FileStore(path, world_size), rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform rerendezvous using file:// method")
151
152
153def _torchelastic_use_agent_store() -> bool:
154    return os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) == str(True)
155
156
157def _create_c10d_store(
158    hostname, port, rank, world_size, timeout, use_libuv=True
159) -> Store:
160    """
161    Smartly creates a c10d Store object on ``rank`` based on whether we need to re-use agent store.
162
163    The TCPStore server is assumed to be hosted
164    on ``hostname:port``.
165
166    By default, the TCPStore server uses the asynchronous implementation
167    ``LibUVStoreDaemon`` which utilizes libuv.
168
169    If ``torchelastic_use_agent_store()`` is ``True``, then it is assumed that
170    the agent leader (node rank 0) hosts the TCPStore server (for which the
171    endpoint is specified by the given ``hostname:port``). Hence
172    ALL ranks will create and return a TCPStore client (e.g. ``start_daemon=False``).
173
174    If ``torchelastic_use_agent_store()`` is ``False``, then rank 0 will host
175    the TCPStore (with multi-tenancy) and it is assumed that rank 0's hostname
176    and port are correctly passed via ``hostname`` and ``port``. All
177    non-zero ranks will create and return a TCPStore client.
178    """
179    # check if port is uint16_t
180    if not 0 <= port < 2**16:
181        raise ValueError(f"port must have value from 0 to 65535 but was {port}.")
182
183    if _torchelastic_use_agent_store():
184        attempt = os.environ["TORCHELASTIC_RESTART_COUNT"]
185        tcp_store = TCPStore(hostname, port, world_size, False, timeout)
186        return PrefixStore(f"/worker/attempt_{attempt}", tcp_store)
187    else:
188        start_daemon = rank == 0
189        return TCPStore(
190            hostname,
191            port,
192            world_size,
193            start_daemon,
194            timeout,
195            multi_tenant=True,
196            use_libuv=use_libuv,
197        )
198
199
def _tcp_rendezvous_handler(
    url: str, timeout: timedelta = default_pg_timeout, **kwargs
):
    """
    Rendezvous handler for ``tcp://`` URLs.

    The URL must provide a port plus ``rank`` and ``world_size`` query
    parameters; ``use_libuv`` is optional. This is a generator: it yields a
    single ``(store, rank, world_size)`` triplet, and resuming it afterwards
    raises because re-rendezvous is not supported for this method.

    Args:
        url: ``tcp://host:port?rank=..&world_size=..`` rendezvous URL.
        timeout: Timeout forwarded to the store.
        **kwargs: Accepted for interface compatibility; not used here.

    Yields:
        Tuple of the created store, the rank, and the world size.

    Raises:
        ValueError: If the port, ``rank`` or ``world_size`` is missing.
        RuntimeError: If the generator is resumed after the first yield.
    """

    def _fail(detail):
        return _rendezvous_error("tcp:// rendezvous: " + detail)

    parsed = urlparse(url)
    if not parsed.port:
        raise _fail("port number missing")

    params = _query_to_dict(parsed.query)
    required = (
        ("rank", "rank parameter missing"),
        ("world_size", "world size parameter missing"),
    )
    for key, complaint in required:
        if key not in params:
            raise _fail(complaint)

    rank = int(params["rank"])
    world_size = int(params["world_size"])
    use_libuv = _get_use_libuv_from_query_dict(params)

    host = parsed.hostname
    assert host is not None

    store = _create_c10d_store(
        host, parsed.port, rank, world_size, timeout, use_libuv
    )
    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform re-rendezvous using tcp:// method")
229
230
def _env_rendezvous_handler(
    url: str, timeout: timedelta = default_pg_timeout, **kwargs
):
    """
    Rendezvous handler for ``env://`` URLs.

    ``rank`` and ``world_size`` come from the URL query when present and
    otherwise from the ``RANK`` / ``WORLD_SIZE`` environment variables. The
    store endpoint always comes from ``MASTER_ADDR`` and ``MASTER_PORT``. An
    environment variable that is unset or empty counts as missing. This is a
    generator: it yields a single ``(store, rank, world_size)`` triplet, and
    resuming it afterwards raises because re-rendezvous is not supported.

    Args:
        url: ``env://`` rendezvous URL, optionally with query parameters.
        timeout: Timeout forwarded to the store.
        **kwargs: Accepted for interface compatibility; not used here.

    Yields:
        Tuple of the created store, the rank, and the world size.

    Raises:
        ValueError: If a required environment variable is missing.
        RuntimeError: If the generator is resumed after the first yield.
    """

    def _fail(detail):
        return _rendezvous_error("env:// rendezvous: " + detail)

    def _require_env(name: str) -> str:
        value = os.environ.get(name)
        if value:
            return value
        raise _fail(f"environment variable {name} expected, but not set")

    params = _query_to_dict(urlparse(url).query)

    rank: int = (
        int(params["rank"]) if "rank" in params else int(_require_env("RANK"))
    )
    world_size: int = (
        int(params["world_size"])
        if "world_size" in params
        else int(_require_env("WORLD_SIZE"))
    )

    master_addr: str = _require_env("MASTER_ADDR")
    master_port: int = int(_require_env("MASTER_PORT"))
    use_libuv = _get_use_libuv_from_query_dict(params)

    store = _create_c10d_store(
        master_addr, master_port, rank, world_size, timeout, use_libuv
    )
    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform re-rendezvous using env:// method")
277
278
# Register the handlers bundled with this module under their URL schemes.
# This runs at import time; registering a scheme a second time raises
# RuntimeError in register_rendezvous_handler().
register_rendezvous_handler("tcp", _tcp_rendezvous_handler)
register_rendezvous_handler("env", _env_rendezvous_handler)
register_rendezvous_handler("file", _file_rendezvous_handler)
282