xref: /aosp_15_r20/external/pytorch/torch/distributed/rendezvous.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workertry:
3*da0073e9SAndroid Build Coastguard Worker    from urllib.parse import urlparse, urlunparse
4*da0073e9SAndroid Build Coastguard Workerexcept ImportError as e:
5*da0073e9SAndroid Build Coastguard Worker    raise ImportError(
6*da0073e9SAndroid Build Coastguard Worker        "urllib cannot be found, urlparse from python2 is no longer supported."
7*da0073e9SAndroid Build Coastguard Worker    ) from e
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerimport numbers
10*da0073e9SAndroid Build Coastguard Workerimport os
11*da0073e9SAndroid Build Coastguard Workerimport sys
12*da0073e9SAndroid Build Coastguard Workerfrom datetime import timedelta
13*da0073e9SAndroid Build Coastguard Workerfrom typing import Callable, Dict, Iterator, Optional, Tuple
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Workerfrom torch.distributed import FileStore, PrefixStore, Store, TCPStore
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Workerfrom .constants import default_pg_timeout
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker
# Names exported by ``from torch.distributed.rendezvous import *``.
__all__ = ["register_rendezvous_handler", "rendezvous"]

# Registry of rendezvous handlers keyed by URL scheme (e.g. "tcp").
# Entries are added through register_rendezvous_handler(); each handler is
# annotated as returning an iterator of (store, rank, world_size) triplets.
_rendezvous_handlers: Dict[str, Callable[..., Iterator[Tuple[Store, int, int]]]] = {}
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Worker
def register_rendezvous_handler(scheme, handler):
    """
    Associate a rendezvous handler with a URL scheme.

    Rendezvous is the step in which the processes that are about to run
    collective algorithms discover each other and exchange the information
    they need in order to communicate. Its result is a triplet made of a
    shared key/value store, the rank of the calling process, and the total
    number of participating processes.

    When none of the bundled rendezvous methods fits your execution
    environment, register your own handler under a unique scheme name and
    pass a URL using that scheme to the `rendezvous()` function.

    Args:
        scheme (str): URL scheme that selects this handler.
        handler (function): Generator function yielding the triplet. It is
            invoked when `rendezvous()` is called with a URL that uses
            ``scheme``.

    Raises:
        RuntimeError: if a handler is already registered for ``scheme``.
    """
    # The registry is only mutated, never rebound, so no ``global`` statement
    # is required here.
    already_registered = scheme in _rendezvous_handlers
    if already_registered:
        raise RuntimeError(f"Rendezvous handler for {scheme}:// already registered")
    _rendezvous_handlers[scheme] = handler
53*da0073e9SAndroid Build Coastguard Worker
54*da0073e9SAndroid Build Coastguard Worker
def _query_to_dict(query: str) -> Dict[str, str]:
    """
    Parse a URL query string into a dictionary of string keys and values.

    A query of the form ``"rank=0&world_size=1"`` is converted into
    ``{"rank": "0", "world_size": "1"}``. Values are returned as strings;
    callers convert them (e.g. with ``int``) as needed. No percent-decoding
    is performed.

    Args:
        query (str): The query component of a URL, without the leading ``?``.

    Returns:
        Dict[str, str]: One entry per non-empty ``&``-separated segment. When
        a key occurs more than once, the last occurrence wins.
    """
    query_dict: Dict[str, str] = {}
    for pair in query.split("&"):
        # Skip empty segments produced by "", "a=1&&b=2" or a trailing "&".
        if not pair:
            continue
        # Split on the first "=" only so that values which themselves contain
        # "=" (e.g. base64 padding) are preserved intact. A segment without
        # any "=" maps its key to the empty string instead of failing with an
        # unhelpful IndexError.
        key, _, value = pair.partition("=")
        query_dict[key] = value
    return query_dict
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker
def _get_use_libuv_from_query_dict(query_dict: Dict[str, str]) -> bool:
    """
    Decide from the URL parameters whether the libuv TCPStore backend is used.

    libuv is the default backend for TCPStore. The ``use_libuv`` URL
    parameter takes precedence; when it is absent the ``USE_LIBUV``
    environment variable is consulted, and when that is absent too the value
    defaults to ``"1"``. Only the exact string ``"1"`` enables libuv, so a
    user can select the non-libuv backend with e.g. ``use_libuv=0``.
    """
    env_default = os.environ.get("USE_LIBUV", "1")
    flag = query_dict.get("use_libuv", env_default)
    return flag == "1"
68*da0073e9SAndroid Build Coastguard Worker
69*da0073e9SAndroid Build Coastguard Worker
def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwargs):
    """
    Inject ``rank``/``world_size`` into ``url`` and invoke the matching handler.

    Args:
        url (str): Rendezvous URL; its scheme selects the handler.
        rank (int): Rank of the caller, or ``-1`` when not specified.
        world_size_opt (Optional[int]): World size, ``-1`` when not
            specified, or ``None`` for a dynamic group.
        **kwargs: Forwarded unchanged to the handler.

    Returns:
        Whatever the handler registered for the URL scheme returns.

    Raises:
        AssertionError: if the URL query already carries ``rank`` or
            ``world_size`` while one of them has to be added.
        RuntimeError: if no handler is registered for the URL scheme.
    """
    parsed = urlparse(url)
    is_dynamic = world_size_opt is None

    if is_dynamic:
        world_size = -1
        if parsed.scheme == "env":
            rank = int(os.environ.get("RANK", rank))
            # If the world_size env variable is not present then it is a dynamic group
            world_size = int(os.environ.get("WORLD_SIZE", world_size))
    else:
        world_size = world_size_opt

    if rank != -1 or world_size != -1 or is_dynamic:
        params = _query_to_dict(parsed.query)
        assert (
            "rank" not in params and "world_size" not in params
        ), f"The url: {url} has node-specific arguments(rank, world_size) already."
        if rank != -1:
            params["rank"] = str(rank)
        if world_size != -1 or is_dynamic:
            params["world_size"] = str(world_size)
        # Rebuild the query string with the node-specific arguments appended.
        rebuilt_query = "&".join([f"{key}={value}" for key, value in params.items()])
        parsed = parsed._replace(query=rebuilt_query)
        url = urlunparse(parsed)

    scheme = parsed.scheme
    if scheme not in _rendezvous_handlers:
        raise RuntimeError(f"No rendezvous handler for {scheme}://")
    handler = _rendezvous_handlers[scheme]
    return handler(url, **kwargs)
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker
def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
    """
    Validate the arguments and run the rendezvous handler selected by ``url``.

    Args:
        url (str): Rendezvous URL (``str`` or ``bytes`` passes the check);
            its scheme selects the handler.
        rank (int): Rank of the calling process; defaults to ``-1``.
        world_size (int): Number of participating processes; defaults
            to ``-1``.
        **kwargs: Forwarded unchanged to the handler.

    Returns:
        The result of ``_rendezvous_helper`` for the validated arguments.

    Raises:
        RuntimeError: if ``url`` is neither ``str`` nor ``bytes``, or if
            ``rank`` or ``world_size`` is not an integral number.
    """
    url_ok = isinstance(url, (str, bytes))
    if not url_ok:
        raise RuntimeError(f"`url` must be a string. {type(url)}: {url}")

    integral = numbers.Integral

    rank_ok = isinstance(rank, integral)
    if not rank_ok:
        raise RuntimeError(f"`rank` must be an integer. {rank}")

    world_size_ok = isinstance(world_size, integral)
    if not world_size_ok:
        raise RuntimeError(f"`world_size` must be an integer. {world_size}")

    return _rendezvous_helper(url, rank, world_size, **kwargs)
110*da0073e9SAndroid Build Coastguard Worker
111*da0073e9SAndroid Build Coastguard Worker
def _create_store_from_options(backend_options, rank):
    """
    Return the store produced by rendezvous for ``backend_options.init_method``.

    The rendezvous is run with ``world_size_opt=None``; only the first
    triplet is consumed, and its rank and world size are discarded.
    """
    triplets = _rendezvous_helper(backend_options.init_method, rank, None)
    store, _unused_rank, _unused_world_size = next(triplets)
    return store
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker
def _rendezvous_error(msg):
    """
    Build (without raising) a ``ValueError`` describing a rendezvous failure.

    ``msg`` is appended to a fixed prefix naming torch.distributed
    initialization; the caller is responsible for raising the result.
    """
    prefix = "Error initializing torch.distributed using "
    return ValueError(prefix + msg)
119*da0073e9SAndroid Build Coastguard Worker
120*da0073e9SAndroid Build Coastguard Worker
def _file_rendezvous_handler(url: str, **kwargs):
    """
    Rendezvous handler for ``file://`` URLs.

    The URL must carry a path plus ``rank`` and ``world_size`` query
    parameters. Yields a single ``(FileStore, rank, world_size)`` triplet;
    resuming the generator afterwards raises ``RuntimeError`` because
    re-rendezvous is not supported by this method.

    Raises:
        ValueError: if the path, ``rank`` or ``world_size`` is missing.
    """

    def _error(msg):
        return _rendezvous_error("file:// rendezvous: " + msg)

    parsed = urlparse(url)
    if sys.platform == "win32":
        import urllib.request

        # On Windows the netloc is treated as part of the path before the
        # URL path is converted to a local path.
        path = urllib.request.url2pathname(parsed.netloc + parsed.path)
        if path:
            # Normalizing an empty string produces ".", which is not expected.
            path = os.path.normpath(path)
    else:
        path = parsed.path

    if not path:
        raise _error("path missing")

    params = _query_to_dict(parsed.query)
    required = (
        ("rank", "rank parameter missing"),
        ("world_size", "world size parameter missing"),
    )
    for name, problem in required:
        if name not in params:
            raise _error(problem)

    rank = int(params["rank"])
    world_size = int(params["world_size"])
    yield (FileStore(path, world_size), rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform rerendezvous using file:// method")
151*da0073e9SAndroid Build Coastguard Worker
152*da0073e9SAndroid Build Coastguard Worker
def _torchelastic_use_agent_store() -> bool:
    """
    Report whether ``TORCHELASTIC_USE_AGENT_STORE`` is set to exactly ``"True"``.

    Any other value, or an unset variable, yields ``False``.
    """
    setting = os.environ.get("TORCHELASTIC_USE_AGENT_STORE")
    return setting == "True"
155*da0073e9SAndroid Build Coastguard Worker
156*da0073e9SAndroid Build Coastguard Worker
def _create_c10d_store(
    hostname, port, rank, world_size, timeout, use_libuv=True
) -> Store:
    """
    Create the c10d Store for ``rank``, re-using the agent store when requested.

    The TCPStore server is assumed to be reachable at ``hostname:port``.

    By default, the TCPStore server uses the asynchronous implementation
    ``LibUVStoreDaemon`` which utilizes libuv.

    When ``_torchelastic_use_agent_store()`` is ``True``, the agent leader
    (node rank 0) is assumed to host the TCPStore server at the given
    endpoint, so every rank creates a TCPStore client (no daemon is started)
    and wraps it in a PrefixStore keyed by ``TORCHELASTIC_RESTART_COUNT``.

    Otherwise rank 0 hosts the TCPStore (with multi-tenancy) and it is
    assumed that ``hostname`` and ``port`` identify rank 0; all non-zero
    ranks create a TCPStore client.

    Raises:
        ValueError: if ``port`` is outside ``0..65535``.
    """
    # check if port is uint16_t
    port_in_range = 0 <= port < 2**16
    if not port_in_range:
        raise ValueError(f"port must have value from 0 to 65535 but was {port}.")

    if not _torchelastic_use_agent_store():
        # Only rank 0 starts the server daemon; the other ranks connect to it.
        is_server = rank == 0
        return TCPStore(
            hostname,
            port,
            world_size,
            is_server,
            timeout,
            multi_tenant=True,
            use_libuv=use_libuv,
        )

    # Agent-hosted store: connect as a client and namespace keys per attempt.
    restart_count = os.environ["TORCHELASTIC_RESTART_COUNT"]
    client = TCPStore(hostname, port, world_size, False, timeout)
    return PrefixStore(f"/worker/attempt_{restart_count}", client)
198*da0073e9SAndroid Build Coastguard Worker
199*da0073e9SAndroid Build Coastguard Worker
def _tcp_rendezvous_handler(
    url: str, timeout: timedelta = default_pg_timeout, **kwargs
):
    """
    Rendezvous handler for ``tcp://`` URLs.

    The URL must carry a port plus ``rank`` and ``world_size`` query
    parameters; an optional ``use_libuv`` parameter selects the TCPStore
    backend. Yields a single ``(store, rank, world_size)`` triplet; resuming
    the generator afterwards raises ``RuntimeError`` because re-rendezvous
    is not supported by this method.

    Raises:
        ValueError: if the port, ``rank`` or ``world_size`` is missing.
    """

    def _error(msg):
        return _rendezvous_error("tcp:// rendezvous: " + msg)

    parsed = urlparse(url)
    if not parsed.port:
        raise _error("port number missing")

    params = _query_to_dict(parsed.query)
    required = (
        ("rank", "rank parameter missing"),
        ("world_size", "world size parameter missing"),
    )
    for name, problem in required:
        if name not in params:
            raise _error(problem)

    rank = int(params["rank"])
    world_size = int(params["world_size"])
    use_libuv = _get_use_libuv_from_query_dict(params)

    assert parsed.hostname is not None

    store = _create_c10d_store(
        parsed.hostname, parsed.port, rank, world_size, timeout, use_libuv
    )
    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform re-rendezvous using tcp:// method")
229*da0073e9SAndroid Build Coastguard Worker
230*da0073e9SAndroid Build Coastguard Worker
def _env_rendezvous_handler(
    url: str, timeout: timedelta = default_pg_timeout, **kwargs
):
    """
    Rendezvous handler for ``env://`` URLs.

    ``rank`` and ``world_size`` are taken from the URL query when present
    and otherwise from the ``RANK`` / ``WORLD_SIZE`` environment variables.
    The store endpoint always comes from ``MASTER_ADDR`` and ``MASTER_PORT``.
    Yields a single ``(store, rank, world_size)`` triplet; resuming the
    generator afterwards raises ``RuntimeError`` because re-rendezvous is
    not supported by this method.

    Raises:
        ValueError: if a required environment variable is unset or empty.
    """

    def _error(msg):
        return _rendezvous_error("env:// rendezvous: " + msg)

    def _env_error(var):
        return _error(f"environment variable {var} expected, but not set")

    def _get_env_or_raise(env_var: str) -> str:
        # An empty value is treated the same as an unset variable.
        env_val = os.environ.get(env_var, None)
        if env_val:
            return env_val
        raise _env_error(env_var)

    params = _query_to_dict(urlparse(url).query)

    # Values given in the URL take precedence over the environment.
    rank: int = (
        int(params["rank"]) if "rank" in params else int(_get_env_or_raise("RANK"))
    )
    world_size: int = (
        int(params["world_size"])
        if "world_size" in params
        else int(_get_env_or_raise("WORLD_SIZE"))
    )

    master_addr: str = _get_env_or_raise("MASTER_ADDR")
    master_port: int = int(_get_env_or_raise("MASTER_PORT"))
    use_libuv = _get_use_libuv_from_query_dict(params)

    store = _create_c10d_store(
        master_addr, master_port, rank, world_size, timeout, use_libuv
    )
    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform re-rendezvous using env:// method")
277*da0073e9SAndroid Build Coastguard Worker
278*da0073e9SAndroid Build Coastguard Worker
# Register the rendezvous handlers bundled with this module, in the same
# order as before: tcp, env, file.
for _scheme, _handler in (
    ("tcp", _tcp_rendezvous_handler),
    ("env", _env_rendezvous_handler),
    ("file", _file_rendezvous_handler),
):
    register_rendezvous_handler(_scheme, _handler)
# Do not leave the loop variables behind in the module namespace.
del _scheme, _handler
282