1# Copyright 2021-2022 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15# ----------------------------------------------------------------------------- 16# Imports 17# ----------------------------------------------------------------------------- 18from __future__ import annotations 19import asyncio 20import collections 21import enum 22import functools 23import logging 24import sys 25import warnings 26from typing import ( 27 Awaitable, 28 Set, 29 TypeVar, 30 List, 31 Tuple, 32 Callable, 33 Any, 34 Optional, 35 Union, 36 overload, 37) 38 39from pyee import EventEmitter 40 41from .colors import color 42 43# ----------------------------------------------------------------------------- 44# Logging 45# ----------------------------------------------------------------------------- 46logger = logging.getLogger(__name__) 47 48 49# ----------------------------------------------------------------------------- 50def setup_event_forwarding(emitter, forwarder, event_name): 51 def emit(*args, **kwargs): 52 forwarder.emit(event_name, *args, **kwargs) 53 54 emitter.on(event_name, emit) 55 56 57# ----------------------------------------------------------------------------- 58def composite_listener(cls): 59 """ 60 Decorator that adds a `register` and `deregister` method to a class, which 61 registers/deregisters all methods named `on_<event_name>` as a listener for 62 the <event_name> event with an emitter. 63 """ 64 # pylint: disable=protected-access 65 66 def register(self, emitter): 67 for method_name in dir(cls): 68 if method_name.startswith('on_'): 69 emitter.on(method_name[3:], getattr(self, method_name)) 70 71 def deregister(self, emitter): 72 for method_name in dir(cls): 73 if method_name.startswith('on_'): 74 emitter.remove_listener(method_name[3:], getattr(self, method_name)) 75 76 cls._bumble_register_composite = register 77 cls._bumble_deregister_composite = deregister 78 return cls 79 80 81# ----------------------------------------------------------------------------- 82_Handler = TypeVar('_Handler', bound=Callable) 83 84 85class EventWatcher: 86 '''A wrapper class to control the lifecycle of event handlers better. 87 88 Usage: 89 ``` 90 watcher = EventWatcher() 91 92 def on_foo(): 93 ... 94 watcher.on(emitter, 'foo', on_foo) 95 96 @watcher.on(emitter, 'bar') 97 def on_bar(): 98 ... 99 100 # Close all event handlers watching through this watcher 101 watcher.close() 102 ``` 103 104 As context: 105 ``` 106 with contextlib.closing(EventWatcher()) as context: 107 @context.on(emitter, 'foo') 108 def on_foo(): 109 ... 110 # on_foo() has been removed here! 111 ``` 112 ''' 113 114 handlers: List[Tuple[EventEmitter, str, Callable[..., Any]]] 115 116 def __init__(self) -> None: 117 self.handlers = [] 118 119 @overload 120 def on( 121 self, emitter: EventEmitter, event: str 122 ) -> Callable[[_Handler], _Handler]: ... 123 124 @overload 125 def on(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler: ... 126 127 def on( 128 self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None 129 ) -> Union[_Handler, Callable[[_Handler], _Handler]]: 130 '''Watch an event until the context is closed. 131 132 Args: 133 emitter: EventEmitter to watch 134 event: Event name 135 handler: (Optional) Event handler. When nothing is passed, this method 136 works as a decorator. 137 ''' 138 139 def wrapper(wrapped: _Handler) -> _Handler: 140 self.handlers.append((emitter, event, wrapped)) 141 emitter.on(event, wrapped) 142 return wrapped 143 144 return wrapper if handler is None else wrapper(handler) 145 146 @overload 147 def once( 148 self, emitter: EventEmitter, event: str 149 ) -> Callable[[_Handler], _Handler]: ... 150 151 @overload 152 def once( 153 self, emitter: EventEmitter, event: str, handler: _Handler 154 ) -> _Handler: ... 155 156 def once( 157 self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None 158 ) -> Union[_Handler, Callable[[_Handler], _Handler]]: 159 '''Watch an event for once. 160 161 Args: 162 emitter: EventEmitter to watch 163 event: Event name 164 handler: (Optional) Event handler. When nothing passed, this method works 165 as a decorator. 166 ''' 167 168 def wrapper(wrapped: _Handler) -> _Handler: 169 self.handlers.append((emitter, event, wrapped)) 170 emitter.once(event, wrapped) 171 return wrapped 172 173 return wrapper if handler is None else wrapper(handler) 174 175 def close(self) -> None: 176 for emitter, event, handler in self.handlers: 177 if handler in emitter.listeners(event): 178 emitter.remove_listener(event, handler) 179 180 181# ----------------------------------------------------------------------------- 182_T = TypeVar('_T') 183 184 185class AbortableEventEmitter(EventEmitter): 186 def abort_on(self, event: str, awaitable: Awaitable[_T]) -> Awaitable[_T]: 187 """ 188 Set a coroutine or future to abort when an event occur. 189 """ 190 future = asyncio.ensure_future(awaitable) 191 if future.done(): 192 return future 193 194 def on_event(*_): 195 if future.done(): 196 return 197 msg = f'abort: {event} event occurred.' 198 if isinstance(future, asyncio.Task): 199 # python < 3.9 does not support passing a message on `Task.cancel` 200 if sys.version_info < (3, 9, 0): 201 future.cancel() 202 else: 203 future.cancel(msg) 204 else: 205 future.set_exception(asyncio.CancelledError(msg)) 206 207 def on_done(_): 208 self.remove_listener(event, on_event) 209 210 self.on(event, on_event) 211 future.add_done_callback(on_done) 212 return future 213 214 215# ----------------------------------------------------------------------------- 216class CompositeEventEmitter(AbortableEventEmitter): 217 def __init__(self): 218 super().__init__() 219 self._listener = None 220 221 @property 222 def listener(self): 223 return self._listener 224 225 @listener.setter 226 def listener(self, listener): 227 # pylint: disable=protected-access 228 if self._listener: 229 # Call the deregistration methods for each base class that has them 230 for cls in self._listener.__class__.mro(): 231 if '_bumble_register_composite' in cls.__dict__: 232 cls._bumble_deregister_composite(self._listener, self) 233 self._listener = listener 234 if listener: 235 # Call the registration methods for each base class that has them 236 for cls in listener.__class__.mro(): 237 if '_bumble_deregister_composite' in cls.__dict__: 238 cls._bumble_register_composite(listener, self) 239 240 241# ----------------------------------------------------------------------------- 242class AsyncRunner: 243 class WorkQueue: 244 def __init__(self, create_task=True): 245 self.queue = None 246 self.task = None 247 self.create_task = create_task 248 249 def enqueue(self, coroutine): 250 # Create a task now if we need to and haven't done so already 251 if self.create_task and self.task is None: 252 self.task = asyncio.create_task(self.run()) 253 254 # Lazy-create the coroutine queue 255 if self.queue is None: 256 self.queue = asyncio.Queue() 257 258 # Enqueue the work 259 self.queue.put_nowait(coroutine) 260 261 async def run(self): 262 while True: 263 item = await self.queue.get() 264 try: 265 await item 266 except Exception as error: 267 logger.warning( 268 f'{color("!!! Exception in work queue:", "red")} {error}' 269 ) 270 271 # Shared default queue 272 default_queue = WorkQueue() 273 274 # Shared set of running tasks 275 running_tasks: Set[Awaitable] = set() 276 277 @staticmethod 278 def run_in_task(queue=None): 279 """ 280 Function decorator used to adapt an async function into a sync function 281 """ 282 283 def decorator(func): 284 @functools.wraps(func) 285 def wrapper(*args, **kwargs): 286 coroutine = func(*args, **kwargs) 287 if queue is None: 288 # Spawn the coroutine as a task 289 async def run(): 290 try: 291 await coroutine 292 except Exception: 293 logger.exception(color("!!! Exception in wrapper:", "red")) 294 295 AsyncRunner.spawn(run()) 296 else: 297 # Queue the coroutine to be awaited by the work queue 298 queue.enqueue(coroutine) 299 300 return wrapper 301 302 return decorator 303 304 @staticmethod 305 def spawn(coroutine): 306 """ 307 Spawn a task to run a coroutine in a "fire and forget" mode. 308 309 Using this method instead of just calling `asyncio.create_task(coroutine)` 310 is necessary when you don't keep a reference to the task, because `asyncio` 311 only keeps weak references to alive tasks. 312 """ 313 task = asyncio.create_task(coroutine) 314 AsyncRunner.running_tasks.add(task) 315 task.add_done_callback(AsyncRunner.running_tasks.remove) 316 317 318# ----------------------------------------------------------------------------- 319class FlowControlAsyncPipe: 320 """ 321 Asyncio pipe with flow control. When writing to the pipe, the source is 322 paused (by calling a function passed in when the pipe is created) if the 323 amount of queued data exceeds a specified threshold. 324 """ 325 326 def __init__( 327 self, 328 pause_source, 329 resume_source, 330 write_to_sink=None, 331 drain_sink=None, 332 threshold=0, 333 ): 334 self.pause_source = pause_source 335 self.resume_source = resume_source 336 self.write_to_sink = write_to_sink 337 self.drain_sink = drain_sink 338 self.threshold = threshold 339 self.queue = collections.deque() # Queue of packets 340 self.queued_bytes = 0 # Number of bytes in the queue 341 self.ready_to_pump = asyncio.Event() 342 self.paused = False 343 self.source_paused = False 344 self.pump_task = None 345 346 def start(self): 347 if self.pump_task is None: 348 self.pump_task = asyncio.create_task(self.pump()) 349 350 self.check_pump() 351 352 def stop(self): 353 if self.pump_task is not None: 354 self.pump_task.cancel() 355 self.pump_task = None 356 357 def write(self, packet): 358 self.queued_bytes += len(packet) 359 self.queue.append(packet) 360 361 # Pause the source if we're over the threshold 362 if self.queued_bytes > self.threshold and not self.source_paused: 363 logger.debug(f'pausing source (queued={self.queued_bytes})') 364 self.pause_source() 365 self.source_paused = True 366 367 self.check_pump() 368 369 def pause(self): 370 if not self.paused: 371 self.paused = True 372 if not self.source_paused: 373 self.pause_source() 374 self.source_paused = True 375 self.check_pump() 376 377 def resume(self): 378 if self.paused: 379 self.paused = False 380 if self.source_paused: 381 self.resume_source() 382 self.source_paused = False 383 self.check_pump() 384 385 def can_pump(self): 386 return self.queue and not self.paused and self.write_to_sink is not None 387 388 def check_pump(self): 389 if self.can_pump(): 390 self.ready_to_pump.set() 391 else: 392 self.ready_to_pump.clear() 393 394 async def pump(self): 395 while True: 396 # Wait until we can try to pump packets 397 await self.ready_to_pump.wait() 398 399 # Try to pump a packet 400 if self.can_pump(): 401 packet = self.queue.pop() 402 self.write_to_sink(packet) 403 self.queued_bytes -= len(packet) 404 405 # Drain the sink if we can 406 if self.drain_sink: 407 await self.drain_sink() 408 409 # Check if we can accept more 410 if self.queued_bytes <= self.threshold and self.source_paused: 411 logger.debug(f'resuming source (queued={self.queued_bytes})') 412 self.source_paused = False 413 self.resume_source() 414 415 self.check_pump() 416 417 418# ----------------------------------------------------------------------------- 419async def async_call(function, *args, **kwargs): 420 """ 421 Immediately calls the function with provided args and kwargs, wrapping it in an 422 async function. 423 Rust's `pyo3_asyncio` library needs functions to be marked async to properly inject 424 a running loop. 425 426 result = await async_call(some_function, ...) 427 """ 428 return function(*args, **kwargs) 429 430 431# ----------------------------------------------------------------------------- 432def wrap_async(function): 433 """ 434 Wraps the provided function in an async function. 435 """ 436 return functools.partial(async_call, function) 437 438 439# ----------------------------------------------------------------------------- 440def deprecated(msg: str): 441 """ 442 Throw deprecation warning before execution. 443 """ 444 445 def wrapper(function): 446 @functools.wraps(function) 447 def inner(*args, **kwargs): 448 warnings.warn(msg, DeprecationWarning) 449 return function(*args, **kwargs) 450 451 return inner 452 453 return wrapper 454 455 456# ----------------------------------------------------------------------------- 457def experimental(msg: str): 458 """ 459 Throws a future warning before execution. 460 """ 461 462 def wrapper(function): 463 @functools.wraps(function) 464 def inner(*args, **kwargs): 465 warnings.warn(msg, FutureWarning) 466 return function(*args, **kwargs) 467 468 return inner 469 470 return wrapper 471 472 473# ----------------------------------------------------------------------------- 474class OpenIntEnum(enum.IntEnum): 475 """ 476 Subclass of enum.IntEnum that can hold integer values outside the set of 477 predefined values. This is convenient for implementing protocols where some 478 integer constants may be added over time. 479 """ 480 481 @classmethod 482 def _missing_(cls, value): 483 if not isinstance(value, int): 484 return None 485 486 obj = int.__new__(cls, value) 487 obj._value_ = value 488 obj._name_ = f"{cls.__name__}[{value}]" 489 return obj 490