xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/timer/api.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Copyright (c) Facebook, Inc. and its affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7import abc
8import logging
9import threading
10import time
11from contextlib import contextmanager
12from inspect import getframeinfo, stack
13from typing import Any, Dict, List, Optional, Set
14
15
# Names exported by ``from ... import *``.
__all__ = [
    # data object plus the pluggable client / queue / server interfaces
    "TimerRequest",
    "TimerClient",
    "RequestQueue",
    "TimerServer",
    # module-level convenience API
    "configure",
    "expires",
]

logger = logging.getLogger(__name__)
26
27
class TimerRequest:
    """
    Data object representing a countdown timer acquisition and release
    that is used between the ``TimerClient`` and ``TimerServer``.
    A negative ``expiration_time`` should be interpreted as a "release"
    request.

    Two requests compare equal when ``worker_id``, ``scope_id`` and
    ``expiration_time`` are all equal. Because equality is field-based and
    the fields are mutable, instances are unhashable.

    .. note:: the type of ``worker_id`` is implementation specific.
              It is whatever the TimerServer and TimerClient implementations
              rely on to uniquely identify a worker.
    """

    __slots__ = ["worker_id", "scope_id", "expiration_time"]

    def __init__(self, worker_id: Any, scope_id: str, expiration_time: float):
        self.worker_id = worker_id
        self.scope_id = scope_id
        self.expiration_time = expiration_time

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(worker_id={self.worker_id!r}, "
            f"scope_id={self.scope_id!r}, "
            f"expiration_time={self.expiration_time!r})"
        )

    def __eq__(self, other):
        if not isinstance(other, TimerRequest):
            # Defer to the other operand's reflected ``__eq__``. If that is
            # not implemented either, Python falls back to identity, which is
            # False for distinct objects.
            return NotImplemented
        return (
            self.worker_id == other.worker_id
            and self.scope_id == other.scope_id
            and self.expiration_time == other.expiration_time
        )
55
56
class TimerClient(abc.ABC):
    """
    Worker-side library for acquiring and releasing countdown timers;
    implementations talk to a ``TimerServer`` on the worker's behalf.
    """

    @abc.abstractmethod
    def acquire(self, scope_id: str, expiration_time: float) -> None:
        """
        Starts a timer identified by ``scope_id`` that expires at
        ``expiration_time`` for the worker owning this client. Implementations
        typically register the timer with the ``TimerServer``.
        """
        ...

    @abc.abstractmethod
    def release(self, scope_id: str):
        """
        Cancels the ``scope_id`` timer of the worker this client stands for.
        Once this returns, the countdown on that scope no longer applies.
        """
        ...
78
79
class RequestQueue(abc.ABC):
    """
    Queue of timer acquire/release requests, drained by its consumer.
    """

    @abc.abstractmethod
    def size(self) -> int:
        """
        Reports how many requests are queued at the moment of the call.

        The queue may grow between a ``size`` call and the following ``get``,
        but it must not shrink until ``get`` is called. Consequently both
        of these must hold::

            size = q.size()
            res = q.get(size, timeout=0)
            assert size == len(res)

            size = q.size()
            res = q.get(size * 2, timeout=1)
            assert size <= len(res) <= size * 2
        """
        ...

    @abc.abstractmethod
    def get(self, size: int, timeout: float) -> List[TimerRequest]:
        """
        Blocks for no longer than ``timeout`` seconds and returns at most
        ``size`` timer requests.
        """
        ...
111
112
class TimerServer(abc.ABC):
    """
    Entity that monitors active timers and expires them
    in a timely fashion. This server is responsible for
    reaping workers that have expired timers.

    Concrete servers implement ``register_timers``, ``clear_timers``,
    ``get_expired_timers`` and ``_reap_worker``. This base class calls them
    from a background watchdog thread that ``start`` launches and ``stop``
    shuts down. A stopped server may be started again.
    """

    def __init__(
        self, request_queue: RequestQueue, max_interval: float, daemon: bool = True
    ):
        """
        :param request_queue: Consumer ``RequestQueue``
        :param max_interval: max time (in seconds) to wait
                             for an item in the request_queue; also the
                             timeout ``stop`` uses when joining the
                             watchdog thread
        :param daemon: whether to run the watchdog thread as a daemon
        """
        super().__init__()
        self._request_queue = request_queue
        self._max_interval = max_interval
        self._daemon = daemon
        self._watchdog_thread: Optional[threading.Thread] = None
        # Checked by ``_watchdog_loop`` before every iteration. ``stop`` sets
        # it and ``start`` clears it again.
        self._stop_signaled = False

    @abc.abstractmethod
    def register_timers(self, timer_requests: List[TimerRequest]) -> None:
        """
        Processes the incoming timer requests and registers them with the server.
        The timer request can either be an acquire-timer or release-timer request.
        Timer requests with a negative expiration_time should be interpreted
        as a release-timer request.
        """

    @abc.abstractmethod
    def clear_timers(self, worker_ids: Set[Any]) -> None:
        """
        Clears all timers for the given ``worker_ids``.
        """

    @abc.abstractmethod
    def get_expired_timers(self, deadline: float) -> Dict[Any, List[TimerRequest]]:
        """
        Returns all expired timers for each worker_id. An expired timer
        is a timer for which the expiration_time is less than or equal to
        the provided deadline.
        """

    @abc.abstractmethod
    def _reap_worker(self, worker_id: Any) -> bool:
        """
        Reaps the given worker. Returns True if the worker has been
        successfully reaped, False otherwise. If any uncaught exception
        is thrown from this method, the worker is considered reaped
        and all associated timers will be removed.
        """

    def _reap_worker_no_throw(self, worker_id: Any) -> bool:
        """
        Wraps ``_reap_worker(worker_id)``, if an uncaught exception is
        thrown, then it considers the worker as reaped.
        """
        try:
            return self._reap_worker(worker_id)
        except Exception:
            logger.exception(
                "Uncaught exception thrown from _reap_worker(), "
                "check that the implementation correctly catches exceptions",
            )
            return True

    def _watchdog_loop(self):
        """
        Body of the watchdog thread: runs ``_run_watchdog`` until ``stop`` is
        requested. Exceptions are logged instead of propagated so that a single
        failing iteration does not terminate the thread.
        """
        while not self._stop_signaled:
            try:
                self._run_watchdog()
            except Exception:
                logger.exception("Error running watchdog")

    def _run_watchdog(self):
        """
        Performs one watchdog iteration. It pulls pending requests from the
        queue, waiting at most ``max_interval`` seconds, and registers them.
        It then tries to reap every worker that has an expired timer. Only the
        timers of successfully reaped workers are cleared, so a failed reap is
        retried on the next iteration.
        """
        # Request at least one item so an empty queue makes ``get`` wait
        # (up to max_interval) rather than being asked for zero items.
        batch_size = max(1, self._request_queue.size())
        timer_requests = self._request_queue.get(batch_size, self._max_interval)
        self.register_timers(timer_requests)
        now = time.time()
        reaped_worker_ids = set()
        for worker_id, expired_timers in self.get_expired_timers(now).items():
            logger.info(
                "Reaping worker_id=[%s]." " Expired timers: %s",
                worker_id,
                self._get_scopes(expired_timers),
            )
            if self._reap_worker_no_throw(worker_id):
                logger.info("Successfully reaped worker=[%s]", worker_id)
                reaped_worker_ids.add(worker_id)
            else:
                logger.error(
                    "Error reaping worker=[%s]. Will retry on next watchdog.", worker_id
                )
        self.clear_timers(reaped_worker_ids)

    def _get_scopes(self, timer_requests):
        """Returns the ``scope_id`` of each request, used for logging."""
        return [r.scope_id for r in timer_requests]

    def start(self) -> None:
        """
        Launches the watchdog thread. This may be called again after ``stop``.
        """
        logger.info(
            "Starting %s..." " max_interval=%s," " daemon=%s",
            type(self).__name__,
            self._max_interval,
            self._daemon,
        )
        # Clear any stop request left by a previous ``stop()``. Otherwise the
        # new watchdog thread would see the flag and exit immediately.
        self._stop_signaled = False
        self._watchdog_thread = threading.Thread(
            target=self._watchdog_loop, daemon=self._daemon
        )
        logger.info("Starting watchdog thread...")
        self._watchdog_thread.start()

    def stop(self) -> None:
        """
        Signals the watchdog loop to stop. If a thread was started, waits up
        to ``max_interval`` seconds for it to finish, then forgets it.
        """
        logger.info("Stopping %s", type(self).__name__)
        self._stop_signaled = True
        if self._watchdog_thread is not None:
            logger.info("Stopping watchdog thread...")
            self._watchdog_thread.join(self._max_interval)
            if self._watchdog_thread.is_alive():
                # The thread is still inside an iteration, e.g. a slow reap.
                # Surface this instead of dropping it silently.
                logger.warning(
                    "Watchdog thread did not stop within %s seconds",
                    self._max_interval,
                )
            self._watchdog_thread = None
        else:
            logger.info("No watchdog thread running, doing nothing")
235
236
# Module-wide default client. ``expires`` falls back to it when called
# without an explicit ``client``; installed by ``configure``.
_timer_client: Optional[TimerClient] = None


def configure(timer_client: TimerClient):
    """
    Installs ``timer_client`` as this module's default timer client.

    Call this before using ``expires`` without passing a ``client``.
    """
    global _timer_client
    _timer_client = timer_client
    client_type_name = type(timer_client).__name__
    logger.info("Timer client configured to: %s", client_type_name)
247
248
def _caller_scope() -> str:
    """
    Returns ``"<filename>#<lineno>"`` for the code that entered ``expires``.

    ``expires`` is a ``@contextmanager`` generator. Its body runs inside
    ``contextlib`` machinery (``_GeneratorContextManager.__enter__``, and
    possibly ``ExitStack.enter_context``) rather than directly under the
    user's frame. Frames whose code lives in ``contextlib`` are therefore
    skipped.
    """
    contextlib_file = contextmanager.__code__.co_filename
    # context=0: only filename/lineno are needed, so skip reading source lines.
    frame_infos = stack(0)
    try:
        # [0] is this helper, [1] is the ``expires`` generator body.
        outer = frame_infos[2:]
        for info in outer:
            if info.frame.f_code.co_filename != contextlib_file:
                return f"{info.filename}#{info.lineno}"
        # Defensive fallback; a non-contextlib caller is always expected.
        fallback = outer[0] if outer else frame_infos[-1]
        return f"{fallback.filename}#{fallback.lineno}"
    finally:
        # frame_infos[0] references this very frame; drop the list to avoid
        # a frame <-> local-variable reference cycle.
        del frame_infos


@contextmanager
def expires(
    after: float, scope: Optional[str] = None, client: Optional[TimerClient] = None
):
    """
    Acquires a countdown timer that expires in ``after`` seconds from now,
    unless the code-block that it wraps is finished within the timeframe.
    When the timer expires, this worker is eligible to be reaped. The
    exact meaning of "reaped" depends on the client implementation. In
    most cases, reaping means to terminate the worker process.
    Note that the worker is NOT guaranteed to be reaped at exactly
    ``time.time() + after``, but rather the worker is "eligible" for being
    reaped and the ``TimerServer`` that the client talks to will ultimately
    make the decision when and how to reap the workers with expired timers.

    :param after: number of seconds from now until the timer expires
    :param scope: identifier of the timer; defaults to
                  ``"<filename>#<lineno>"`` of the code that enters this
                  context manager
    :param client: client used to acquire/release the timer; defaults to
                   the client installed with ``configure``
    :raises RuntimeError: if no ``client`` is given and ``configure`` has
                          not been called

    Usage::

        torch.distributed.elastic.timer.configure(LocalTimerClient())
        with expires(after=10):
            torch.distributed.all_reduce(...)
    """
    if client is None:
        if _timer_client is None:
            raise RuntimeError("Configure timer client before using countdown timers.")
        client = _timer_client
    if scope is None:
        # Looking only one frame up would land in contextlib's ``__enter__``
        # and give every default-scoped timer the same id.
        scope = _caller_scope()
    expiration = time.time() + after
    client.acquire(scope, expiration)
    try:
        yield
    finally:
        # Always release, even when the wrapped block raises.
        client.release(scope)
284