xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/rendezvous/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Copyright (c) Facebook, Inc. and its affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8import ipaddress
9import random
10import re
11import socket
12import time
13import weakref
14from datetime import timedelta
15from threading import Event, Thread
16from typing import Any, Callable, Dict, Optional, Tuple, Union
17
18
19__all__ = ["parse_rendezvous_endpoint"]
20
21
def _parse_rendezvous_config(config_str: str) -> Dict[str, str]:
    """Parse a rendezvous configuration string into a dictionary.

    Args:
        config_str:
            A string in format <key1>=<value1>,...,<keyN>=<valueN>. Leading
            and trailing whitespace around the whole string, the keys, and
            the values is ignored.

    Returns:
        A dictionary mapping each key to its value. A blank ``config_str``
        yields an empty dictionary. If a key occurs more than once, the last
        occurrence wins.

    Raises:
        ValueError:
            If an entry has an empty key, or if an entry has no ``=`` or an
            empty value.
    """
    stripped = config_str.strip()
    if not stripped:
        return {}

    parsed: Dict[str, str] = {}
    for entry in stripped.split(","):
        # Only the first "=" separates the key from the value; any further
        # "=" characters belong to the value.
        raw_key, _, raw_value = entry.partition("=")

        key = raw_key.strip()
        if not key:
            raise ValueError(
                "The rendezvous configuration string must be in format "
                "<key1>=<value1>,...,<keyN>=<valueN>."
            )

        # An entry without "=" yields an empty value and is rejected below,
        # exactly like an entry with an explicitly empty value.
        value = raw_value.strip()
        if not value:
            raise ValueError(
                f"The rendezvous configuration option '{key}' must have a value specified."
            )

        parsed[key] = value

    return parsed
58
59
def _try_parse_port(port_str: str) -> Optional[int]:
    """Try to extract the port number from ``port_str``.

    Args:
        port_str:
            The text to interpret as a port number. It must consist solely of
            one to five ASCII decimal digits.

    Returns:
        The parsed integer, or ``None`` if ``port_str`` is empty or is not a
        1-5 digit decimal number. No range check is performed here, so values
        such as 99999 are returned as is.
    """
    # ``fullmatch`` is used instead of ``match`` with a ``$`` anchor because
    # ``$`` also matches right before a trailing newline, which would let a
    # string such as "80\n" pass as a valid port.
    if port_str and re.fullmatch(r"[0-9]{1,5}", port_str):
        return int(port_str)
    return None
65
66
def parse_rendezvous_endpoint(
    endpoint: Optional[str], default_port: int
) -> Tuple[str, int]:
    """Extract the hostname and the port number from a rendezvous endpoint.

    Args:
        endpoint:
            A string in format <hostname>[:<port>]. An IPv6 address may be
            enclosed in square brackets; the brackets are removed from the
            returned hostname. Surrounding whitespace is ignored.
        default_port:
            The port number to use if the endpoint does not include one.

    Returns:
        A tuple of hostname and port number. If ``endpoint`` is ``None`` or
        blank, ``("localhost", default_port)`` is returned.

    Raises:
        ValueError:
            If the port is not a decimal number in the range 0-65535, or if
            the hostname contains characters other than word characters,
            dots, colons, and hyphens.
    """
    if endpoint is not None:
        endpoint = endpoint.strip()

    if not endpoint:
        return ("localhost", default_port)

    # An endpoint that starts and ends with brackets represents an IPv6
    # address without a port, so it must not be split on its colons.
    if endpoint[0] == "[" and endpoint[-1] == "]":
        parts = [endpoint]
    else:
        parts = endpoint.rsplit(":", 1)

    host, *rest = parts

    # Sanitize the IPv6 address.
    if len(host) > 1 and host[0] == "[" and host[-1] == "]":
        host = host[1:-1]

    if len(rest) == 1:
        port = _try_parse_port(rest[0])
        if port is None or port >= 2**16:
            raise ValueError(
                f"The port number of the rendezvous endpoint '{endpoint}' must be an integer "
                "between 0 and 65536."
            )
    else:
        port = default_port

    # ``fullmatch`` is used instead of ``match`` with a ``$`` anchor because
    # ``$`` also matches right before a trailing newline, which would let a
    # hostname such as "host\n" through.
    if not re.fullmatch(r"[\w\.:-]+", host):
        raise ValueError(
            f"The hostname of the rendezvous endpoint '{endpoint}' must be a dot-separated list of "
            "labels, an IPv4 address, or an IPv6 address."
        )

    return host, port
114
115
def _matches_machine_hostname(host: str) -> bool:
    """Tell whether ``host`` refers to the machine running this code.

    ``host`` is compared against the hostname of this machine, the canonical
    names and IP addresses that the hostname resolves to, and the loopback
    addresses. Note that a false negative is possible if this machine has
    CNAME records beyond its FQDN or IP addresses assigned to secondary NICs.

    Args:
        host:
            A hostname or an IP address literal.

    Returns:
        ``True`` if ``host`` is found to match this machine, ``False``
        otherwise.
    """
    if host == "localhost":
        return True

    # `literal_ip` stays None when `host` is a name rather than an IP literal.
    try:
        literal_ip = ipaddress.ip_address(host)
    except ValueError:
        literal_ip = None

    if literal_ip is not None and literal_ip.is_loopback:
        return True

    # Resolve `host` to the IP addresses it maps to; a failed lookup simply
    # leaves us with nothing to compare against.
    try:
        resolved_host = socket.getaddrinfo(
            host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
        )
    except (ValueError, socket.gaierror):
        resolved_host = []

    host_ips = [entry[4][0] for entry in resolved_host]

    machine_name = socket.gethostname()
    if machine_name == host:
        return True

    literal_ip_text = None if literal_ip is None else str(literal_ip)

    machine_entries = socket.getaddrinfo(
        machine_name, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
    )

    # An entry of this machine matches if any of the following holds:
    #   - it carries a canonical name (FQDN) equal to `host`;
    #   - `host` is an IP literal equal to the entry's IP address;
    #   - the entry's IP address is one of the addresses `host` resolved to.
    return any(
        (canonical_name and canonical_name == host)
        or (literal_ip_text is not None and sock_addr[0] == literal_ip_text)
        or sock_addr[0] in host_ips
        for _, _, _, canonical_name, sock_addr in machine_entries
    )
166
167
def _delay(seconds: Union[float, Tuple[float, float]]) -> None:
    """Block the calling thread for the requested amount of time.

    Args:
        seconds:
            Either the delay, in seconds, or a tuple of a lower and an upper
            bound within which a random delay will be picked.
    """
    duration = random.uniform(*seconds) if isinstance(seconds, tuple) else seconds

    # Delays shorter than 10 milliseconds are not worth a sleep call and are
    # skipped entirely.
    if duration >= 0.01:
        time.sleep(duration)
181
182
class _PeriodicTimer:
    """Run a function repeatedly on a background thread at a fixed interval.

    The function is first invoked one full interval after :meth:`start` is
    called and then again after each subsequent interval until the timer is
    cancelled or garbage collected.

    Args:
        interval:
            The time to wait between two consecutive runs, as a ``timedelta``.
        function:
            The function to run.
        *args:
            Positional arguments passed to ``function`` on every run.
        **kwargs:
            Keyword arguments passed to ``function`` on every run.
    """

    # The state of the timer is held in a separate context object so that the
    # background thread never has to reference the timer itself; this avoids a
    # reference cycle between the timer and the background thread.
    class _Context:
        interval: float
        function: Callable[..., None]
        args: Tuple[Any, ...]
        kwargs: Dict[str, Any]
        stop_event: Event

    _name: Optional[str]
    _thread: Optional[Thread]
    _finalizer: Optional[weakref.finalize]

    # The context that is shared between the timer and the background thread.
    _ctx: _Context

    def __init__(
        self,
        interval: timedelta,
        function: Callable[..., None],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        ctx = self._Context()
        ctx.interval = interval.total_seconds()
        ctx.function = function  # type: ignore[assignment]
        ctx.args = args
        ctx.kwargs = kwargs
        ctx.stop_event = Event()

        self._ctx = ctx
        self._name = None
        self._thread = None
        self._finalizer = None

    @property
    def name(self) -> Optional[str]:
        """Get the name of the timer, or ``None`` if no name has been set."""
        return self._name

    def set_name(self, name: str) -> None:
        """Set the name of the timer.

        The name is given to the background thread once the timer starts and
        serves for debugging and troubleshooting purposes.

        Raises:
            RuntimeError:
                If the timer has already been started.
        """
        if self._thread is not None:
            raise RuntimeError("The timer has already started.")

        self._name = name

    def start(self) -> None:
        """Start the timer.

        Raises:
            RuntimeError:
                If the timer has already been started.
        """
        if self._thread is not None:
            raise RuntimeError("The timer has already started.")

        # The thread only receives the shared context and a static method, so
        # it holds no reference to the timer object.
        worker = Thread(
            target=self._run,
            name=self._name or "PeriodicTimer",
            args=(self._ctx,),
            daemon=True,
        )
        self._thread = worker

        # A regular finalizer (a.k.a. __del__) is deliberately not used for
        # stopping the timer, because joining a daemon thread during the
        # interpreter shutdown can cause deadlocks. weakref.finalize behaves
        # consistently regardless of the GC implementation.
        finalizer = weakref.finalize(
            self, self._stop_thread, worker, self._ctx.stop_event
        )

        # Do not try to stop the background thread during the interpreter
        # shutdown; at that point we do not even know whether it still exists.
        finalizer.atexit = False

        self._finalizer = finalizer

        worker.start()

    def cancel(self) -> None:
        """Stop the timer at the next opportunity.

        This is a no-op if the timer has not been started. Otherwise it
        signals the background thread to stop and waits for it to exit.
        """
        finalizer = self._finalizer
        if finalizer is not None:
            finalizer()

    @staticmethod
    def _run(ctx) -> None:
        # Event.wait() returns True only once the stop event is set; a False
        # result means a full interval elapsed and the function is due.
        while True:
            if ctx.stop_event.wait(ctx.interval):
                return
            ctx.function(*ctx.args, **ctx.kwargs)

    @staticmethod
    def _stop_thread(thread, stop_event):
        # Wake the background thread up, then wait until it has exited.
        stop_event.set()
        thread.join()
285