# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import abc
import logging
import threading
import time
from contextlib import contextmanager
from inspect import getframeinfo, stack
from typing import Any, Dict, List, Optional, Set


__all__ = [
    "TimerRequest",
    "TimerClient",
    "RequestQueue",
    "TimerServer",
    "configure",
    "expires",
]

logger = logging.getLogger(__name__)


class TimerRequest:
    """
    Data object representing a countdown timer acquisition and release
    that is used between the ``TimerClient`` and ``TimerServer``.
    A negative ``expiration_time`` should be interpreted as a "release"
    request.

    .. note:: the type of ``worker_id`` is implementation specific.
              It is whatever the TimerServer and TimerClient implementations
              use to uniquely identify a worker.
    """

    __slots__ = ["worker_id", "scope_id", "expiration_time"]

    def __init__(self, worker_id: Any, scope_id: str, expiration_time: float):
        self.worker_id = worker_id
        self.scope_id = scope_id
        self.expiration_time = expiration_time

    def __eq__(self, other):
        """
        Two requests are equal when worker id, scope id and expiration time
        all match. Defining ``__eq__`` without ``__hash__`` leaves instances
        unhashable, which is appropriate for a mutable object.
        """
        if not isinstance(other, TimerRequest):
            # Let the other operand try; Python falls back to identity
            # comparison (i.e. ``False`` here) when nobody handles it.
            return NotImplemented
        return (
            self.worker_id == other.worker_id
            and self.scope_id == other.scope_id
            and self.expiration_time == other.expiration_time
        )

    def __repr__(self) -> str:
        """Readable representation for log messages and debugging."""
        return (
            f"{type(self).__name__}("
            f"worker_id={self.worker_id!r}, "
            f"scope_id={self.scope_id!r}, "
            f"expiration_time={self.expiration_time!r})"
        )


class TimerClient(abc.ABC):
    """
    Client library to acquire and release countdown timers by communicating
    with the TimerServer.
    """

    @abc.abstractmethod
    def acquire(self, scope_id: str, expiration_time: float) -> None:
        """
        Acquires a timer for the worker that holds this client object
        given the scope_id and expiration_time. Typically registers
        the timer with the TimerServer.
        """

    @abc.abstractmethod
    def release(self, scope_id: str):
        """
        Releases the timer for the ``scope_id`` on the worker this
        client represents. After this method is
        called, the countdown timer on the scope is no longer in effect.
        """


class RequestQueue(abc.ABC):
    """
    Consumer queue holding timer acquisition/release requests
    """

    @abc.abstractmethod
    def size(self) -> int:
        """
        Returns the size of the queue at the time this method is called.
        Note that by the time ``get`` is called the size of the queue
        may have increased. The size of the queue should not decrease
        until the ``get`` method is called. That is, the following assertion
        should hold:

        size = q.size()
        res = q.get(size, timeout=0)
        assert size == len(res)

        -- or --

        size = q.size()
        res = q.get(size * 2, timeout=1)
        assert size <= len(res) <= size * 2
        """

    @abc.abstractmethod
    def get(self, size: int, timeout: float) -> List[TimerRequest]:
        """
        Gets up to ``size`` number of timer requests in a blocking fashion
        (no more than ``timeout`` seconds).
        """


class TimerServer(abc.ABC):
    """
    Entity that monitors active timers and expires them
    in a timely fashion. This server is responsible for
    reaping workers that have expired timers.
    """

    def __init__(
        self, request_queue: RequestQueue, max_interval: float, daemon: bool = True
    ):
        """
        :param request_queue: Consumer ``RequestQueue``
        :param max_interval: max time (in seconds) to wait
                             for an item in the request_queue
        :param daemon: whether to run the watchdog thread as a daemon
        """
        super().__init__()
        self._request_queue = request_queue
        self._max_interval = max_interval
        self._daemon = daemon
        self._watchdog_thread: Optional[threading.Thread] = None
        self._stop_signaled = False

    @abc.abstractmethod
    def register_timers(self, timer_requests: List[TimerRequest]) -> None:
        """
        Processes the incoming timer requests and registers them with the server.
        The timer request can either be an acquire-timer or release-timer request.
        Timer requests with a negative expiration_time should be interpreted
        as a release-timer request.
        """

    @abc.abstractmethod
    def clear_timers(self, worker_ids: Set[Any]) -> None:
        """
        Clears all timers for the given ``worker_ids``.
        """

    @abc.abstractmethod
    def get_expired_timers(self, deadline: float) -> Dict[Any, List[TimerRequest]]:
        """
        Returns all expired timers for each worker_id. An expired timer
        is a timer for which the expiration_time is less than or equal to
        the provided deadline.
        """

    @abc.abstractmethod
    def _reap_worker(self, worker_id: Any) -> bool:
        """
        Reaps the given worker. Returns True if the worker has been
        successfully reaped, False otherwise. If any uncaught exception
        is thrown from this method, the worker is considered reaped
        and all associated timers will be removed.
        """

    def _reap_worker_no_throw(self, worker_id: Any) -> bool:
        """
        Wraps ``_reap_worker(worker_id)``, if an uncaught exception is
        thrown, then it considers the worker as reaped.
        """
        try:
            return self._reap_worker(worker_id)
        except Exception:
            logger.exception(
                "Uncaught exception thrown from _reap_worker(), "
                "check that the implementation correctly catches exceptions",
            )
            return True

    def _watchdog_loop(self):
        """
        Body of the watchdog thread: runs one watchdog pass after another
        until ``stop()`` raises the stop flag. A failing pass is logged and
        the loop keeps going.
        """
        while not self._stop_signaled:
            try:
                self._run_watchdog()
            except Exception:
                logger.exception("Error running watchdog")

    def _run_watchdog(self):
        """
        Single watchdog pass: drain pending requests from the queue, register
        them, then reap every worker that has a timer expired as of now and
        clear the timers of the workers that were reaped.
        """
        # Ask for at least one item so that ``get`` blocks (up to
        # ``max_interval``) when the queue is currently empty.
        batch_size = max(1, self._request_queue.size())
        timer_requests = self._request_queue.get(batch_size, self._max_interval)
        self.register_timers(timer_requests)
        now = time.time()
        reaped_worker_ids = set()
        for worker_id, expired_timers in self.get_expired_timers(now).items():
            logger.info(
                "Reaping worker_id=[%s]. Expired timers: %s",
                worker_id,
                self._get_scopes(expired_timers),
            )
            if self._reap_worker_no_throw(worker_id):
                logger.info("Successfully reaped worker=[%s]", worker_id)
                reaped_worker_ids.add(worker_id)
            else:
                # Timers are kept, so the worker shows up as expired again
                # on the next pass.
                logger.error(
                    "Error reaping worker=[%s]. Will retry on next watchdog.", worker_id
                )
        self.clear_timers(reaped_worker_ids)

    def _get_scopes(self, timer_requests):
        """Returns the ``scope_id`` of each request, in order."""
        return [r.scope_id for r in timer_requests]

    def start(self) -> None:
        """
        Creates and starts the watchdog thread.

        .. note:: ``stop()`` sets the stop flag and this method does not
                  reset it, so a server that has been stopped does not
                  resume watching when started again.
        """
        logger.info(
            "Starting %s... max_interval=%s, daemon=%s",
            type(self).__name__,
            self._max_interval,
            self._daemon,
        )
        self._watchdog_thread = threading.Thread(
            target=self._watchdog_loop, daemon=self._daemon
        )
        logger.info("Starting watchdog thread...")
        self._watchdog_thread.start()

    def stop(self) -> None:
        """
        Signals the watchdog loop to stop and waits up to ``max_interval``
        seconds for the watchdog thread to finish. Does nothing beyond
        setting the stop flag when no thread was started.
        """
        logger.info("Stopping %s", type(self).__name__)
        self._stop_signaled = True
        if self._watchdog_thread:
            logger.info("Stopping watchdog thread...")
            # Bounded join: the reference is dropped afterwards whether or
            # not the thread has actually exited within the timeout.
            self._watchdog_thread.join(self._max_interval)
            self._watchdog_thread = None
        else:
            logger.info("No watchdog thread running, doing nothing")


_timer_client: Optional[TimerClient] = None


def configure(timer_client: TimerClient):
    """
    Configures a timer client. Must be called before using ``expires``.
    """
    global _timer_client
    _timer_client = timer_client
    logger.info("Timer client configured to: %s", type(_timer_client).__name__)


@contextmanager
def expires(
    after: float, scope: Optional[str] = None, client: Optional[TimerClient] = None
):
    """
    Acquires a countdown timer that expires in ``after`` seconds from now,
    unless the code-block that it wraps is finished within the timeframe.
    When the timer expires, this worker is eligible to be reaped. The
    exact meaning of "reaped" depends on the client implementation. In
    most cases, reaping means to terminate the worker process.
    Note that the worker is NOT guaranteed to be reaped at exactly
    ``time.time() + after``, but rather the worker is "eligible" for being
    reaped and the ``TimerServer`` that the client talks to will ultimately
    make the decision when and how to reap the workers with expired timers.

    :param after: number of seconds from now until the timer expires
    :param scope: identifier of the timer; defaults to
                  ``"<caller file>#<caller line number>"``
    :param client: client to use; defaults to the one set via ``configure``
    :raises RuntimeError: if no ``client`` is given and none was configured

    Usage::

        torch.distributed.elastic.timer.configure(LocalTimerClient())
        with expires(after=10):
            torch.distributed.all_reduce(...)
    """
    if client is None:
        if _timer_client is None:
            raise RuntimeError("Configure timer client before using countdown timers.")
        client = _timer_client
    if scope is None:
        # grab the caller file + lineno
        #
        # This generator body is driven by contextlib (its ``__enter__``
        # calls ``next()`` on the generator), so the frame directly above
        # this one belongs to contextlib, not to the code that wrote the
        # ``with`` statement. Skip contextlib frames so that distinct call
        # sites get distinct default scopes.
        outer_frames = stack()[1:]
        caller_frame = next(
            (
                info.frame
                for info in outer_frames
                if info.frame.f_globals.get("__name__") != contextmanager.__module__
            ),
            outer_frames[0].frame,
        )
        caller = getframeinfo(caller_frame)
        scope = f"{caller.filename}#{caller.lineno}"
    expiration = time.time() + after
    client.acquire(scope, expiration)
    try:
        yield
    finally:
        client.release(scope)