1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19import asyncio
20import collections
21import enum
22import functools
23import logging
24import sys
25import warnings
26from typing import (
27    Awaitable,
28    Set,
29    TypeVar,
30    List,
31    Tuple,
32    Callable,
33    Any,
34    Optional,
35    Union,
36    overload,
37)
38
39from pyee import EventEmitter
40
41from .colors import color
42
43# -----------------------------------------------------------------------------
44# Logging
45# -----------------------------------------------------------------------------
46logger = logging.getLogger(__name__)
47
48
49# -----------------------------------------------------------------------------
def setup_event_forwarding(emitter, forwarder, event_name):
    """
    Re-emit every `event_name` event raised by `emitter` on `forwarder`.

    The positional and keyword arguments of the original emission are passed
    through unchanged.
    """

    def forward(*args, **kwargs):
        forwarder.emit(event_name, *args, **kwargs)

    emitter.on(event_name, forward)
55
56
57# -----------------------------------------------------------------------------
def composite_listener(cls):
    """
    Class decorator that gives `cls` a pair of hooks for attaching to an emitter.

    `_bumble_register_composite(self, emitter)` subscribes every attribute of
    the class whose name starts with `on_` to the event named by the rest of
    the attribute name (e.g. `on_connection` -> `connection`).
    `_bumble_deregister_composite(self, emitter)` unsubscribes them again.
    The attribute list comes from `dir(cls)` at call time, so inherited
    `on_*` methods are included.
    """
    # pylint: disable=protected-access

    def handlers_of(instance):
        # Yield (event name, bound attribute) pairs, one per `on_*` attribute.
        for attribute in dir(cls):
            if attribute.startswith('on_'):
                yield attribute[3:], getattr(instance, attribute)

    def register(self, emitter):
        for event_name, handler in handlers_of(self):
            emitter.on(event_name, handler)

    def deregister(self, emitter):
        for event_name, handler in handlers_of(self):
            emitter.remove_listener(event_name, handler)

    cls._bumble_register_composite = register
    cls._bumble_deregister_composite = deregister
    return cls
79
80
81# -----------------------------------------------------------------------------
_Handler = TypeVar('_Handler', bound=Callable)


class EventWatcher:
    '''A wrapper class to control the lifecycle of event handlers better.

    Every handler registered through `on()` or `once()` is remembered, and
    `close()` detaches all of them from their emitters in one call.

    Usage:
    ```
    watcher = EventWatcher()

    def on_foo():
        ...
    watcher.on(emitter, 'foo', on_foo)

    @watcher.on(emitter, 'bar')
    def on_bar():
        ...

    # Close all event handlers watching through this watcher
    watcher.close()
    ```

    As context:
    ```
    with contextlib.closing(EventWatcher()) as context:
        @context.on(emitter, 'foo')
        def on_foo():
            ...
    # on_foo() has been removed here!
    ```
    '''

    # (emitter, event name, handler) for each handler registered through this
    # watcher since it was created or last closed.
    handlers: List[Tuple[EventEmitter, str, Callable[..., Any]]]

    def __init__(self) -> None:
        self.handlers = []

    @overload
    def on(
        self, emitter: EventEmitter, event: str
    ) -> Callable[[_Handler], _Handler]: ...

    @overload
    def on(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler: ...

    def on(
        self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
    ) -> Union[_Handler, Callable[[_Handler], _Handler]]:
        '''Watch an event until the watcher is closed.

        Args:
            emitter: EventEmitter to watch
            event: Event name
            handler: (Optional) Event handler. When nothing is passed, this method
            works as a decorator.

        Returns:
            `handler` itself when given, otherwise a decorator that registers the
            decorated function and returns it unchanged.
        '''

        def wrapper(wrapped: _Handler) -> _Handler:
            self.handlers.append((emitter, event, wrapped))
            emitter.on(event, wrapped)
            return wrapped

        return wrapper if handler is None else wrapper(handler)

    @overload
    def once(
        self, emitter: EventEmitter, event: str
    ) -> Callable[[_Handler], _Handler]: ...

    @overload
    def once(
        self, emitter: EventEmitter, event: str, handler: _Handler
    ) -> _Handler: ...

    def once(
        self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
    ) -> Union[_Handler, Callable[[_Handler], _Handler]]:
        '''Watch an event only once (registered with `emitter.once`).

        Args:
            emitter: EventEmitter to watch
            event: Event name
            handler: (Optional) Event handler. When nothing is passed, this method
            works as a decorator.

        Returns:
            `handler` itself when given, otherwise a decorator that registers the
            decorated function and returns it unchanged.
        '''

        def wrapper(wrapped: _Handler) -> _Handler:
            self.handlers.append((emitter, event, wrapped))
            emitter.once(event, wrapped)
            return wrapped

        return wrapper if handler is None else wrapper(handler)

    def close(self) -> None:
        '''Detach every handler registered through this watcher.

        Handlers that are no longer attached (for example a `once` handler that
        already fired) are skipped. The tracked handlers are forgotten
        afterwards. A second `close()` is then a no-op, so it cannot detach a
        handler that other code re-attached to the emitter after the first
        close, and the watcher stops keeping emitters and handlers alive.
        '''
        handlers, self.handlers = self.handlers, []
        for emitter, event, handler in handlers:
            if handler in emitter.listeners(event):
                emitter.remove_listener(event, handler)
179
180
181# -----------------------------------------------------------------------------
_T = TypeVar('_T')


class AbortableEventEmitter(EventEmitter):
    """EventEmitter that can abort pending awaitables when an event is emitted."""

    def abort_on(self, event: str, awaitable: Awaitable[_T]) -> Awaitable[_T]:
        """
        Set a coroutine or future to abort when an event occurs.

        Args:
            event: Name of the event that triggers the abort.
            awaitable: Coroutine or future to watch. It is wrapped with
                `asyncio.ensure_future`, so a coroutine is scheduled as a Task.

        Returns:
            The future wrapping `awaitable`. If `event` is emitted before that
            future is done, a Task is cancelled (with a message on Python >= 3.9).
            A plain Future instead gets an `asyncio.CancelledError` as its
            exception. The temporary event listener is removed once the future
            is done.
        """
        future = asyncio.ensure_future(awaitable)
        if future.done():
            return future

        def on_event(*_):
            if future.done():
                return
            msg = f'abort: {event} event occurred.'
            if isinstance(future, asyncio.Task):
                # python < 3.9 does not support passing a message on `Task.cancel`
                if sys.version_info < (3, 9, 0):
                    future.cancel()
                else:
                    future.cancel(msg)
            else:
                future.set_exception(asyncio.CancelledError(msg))

        def on_done(_):
            # The listener may already be gone (e.g. `remove_all_listeners()` was
            # called on this emitter). Removing it again would raise from inside
            # this done-callback, so only remove it if it is still registered.
            if on_event in self.listeners(event):
                self.remove_listener(event, on_event)

        self.on(event, on_event)
        future.add_done_callback(on_done)
        return future
213
214
215# -----------------------------------------------------------------------------
class CompositeEventEmitter(AbortableEventEmitter):
    """
    Event emitter with a single attachable `listener` object.

    Assigning to `listener` first deregisters the previous listener (if any),
    then registers the new one. Both steps go through the hooks installed by
    `composite_listener` on each class of the listener's MRO.
    """

    def __init__(self):
        super().__init__()
        self._listener = None

    @property
    def listener(self):
        """The currently attached listener object, or None."""
        return self._listener

    @listener.setter
    def listener(self, listener):
        # pylint: disable=protected-access
        # Use explicit `is not None` checks: a listener object that happens to be
        # falsy (e.g. defines `__len__`) must still be (de)registered.
        if self._listener is not None:
            # Call the deregistration methods for each base class that has them
            for cls in self._listener.__class__.mro():
                if '_bumble_deregister_composite' in cls.__dict__:
                    cls._bumble_deregister_composite(self._listener, self)
        self._listener = listener
        if listener is not None:
            # Call the registration methods for each base class that has them
            for cls in listener.__class__.mro():
                if '_bumble_register_composite' in cls.__dict__:
                    cls._bumble_register_composite(listener, self)
239
240
241# -----------------------------------------------------------------------------
class AsyncRunner:
    """Helpers for running coroutines from synchronous (callback) code."""

    class WorkQueue:
        """
        FIFO of coroutines that `run()` awaits one after the other.

        With `create_task=True` (the default), the `run()` loop is started as a
        task on the first `enqueue()`. Otherwise the owner is responsible for
        awaiting `run()`, either before or after work has been enqueued.
        """

        def __init__(self, create_task=True):
            self.queue = None
            self.task = None
            self.create_task = create_task

        def _get_queue(self):
            # Lazily create the coroutine queue on first use, from either
            # `enqueue()` or `run()`, whichever happens first.
            if self.queue is None:
                self.queue = asyncio.Queue()
            return self.queue

        def enqueue(self, coroutine):
            # Create a task now if we need to and haven't done so already
            if self.create_task and self.task is None:
                self.task = asyncio.create_task(self.run())

            # Enqueue the work
            self._get_queue().put_nowait(coroutine)

        async def run(self):
            # The queue may not exist yet when `run()` is started by the owner
            # (create_task=False) before anything was enqueued.
            queue = self._get_queue()
            while True:
                item = await queue.get()
                try:
                    await item
                except Exception as error:
                    logger.warning(
                        f'{color("!!! Exception in work queue:", "red")} {error}'
                    )
                finally:
                    # Keep the queue's unfinished-task count accurate so that
                    # `queue.join()` can be used to wait for pending work.
                    queue.task_done()

    # Shared default queue
    default_queue = WorkQueue()

    # Shared set of running tasks
    running_tasks: Set[Awaitable] = set()

    @staticmethod
    def run_in_task(queue=None):
        """
        Function decorator used to adapt an async function into a sync function.

        Calling the decorated function creates the coroutine and returns None
        immediately. Without `queue`, the coroutine is spawned as its own task
        and exceptions are logged. With `queue`, it is appended to that
        `WorkQueue` and awaited in order.
        """

        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                coroutine = func(*args, **kwargs)
                if queue is None:
                    # Spawn the coroutine as a task
                    async def run():
                        try:
                            await coroutine
                        except Exception:
                            logger.exception(color("!!! Exception in wrapper:", "red"))

                    AsyncRunner.spawn(run())
                else:
                    # Queue the coroutine to be awaited by the work queue
                    queue.enqueue(coroutine)

            return wrapper

        return decorator

    @staticmethod
    def spawn(coroutine):
        """
        Spawn a task to run a coroutine in a "fire and forget" mode.

        Using this method instead of just calling `asyncio.create_task(coroutine)`
        is necessary when you don't keep a reference to the task, because `asyncio`
        only keeps weak references to alive tasks.

        Returns:
            The created task. Callers may ignore it, or keep it to await or
            cancel the work.
        """
        task = asyncio.create_task(coroutine)
        AsyncRunner.running_tasks.add(task)
        task.add_done_callback(AsyncRunner.running_tasks.remove)
        return task
316
317
318# -----------------------------------------------------------------------------
class FlowControlAsyncPipe:
    """
    Asyncio pipe with flow control. When writing to the pipe, the source is
    paused (by calling a function passed in when the pipe is created) if the
    amount of queued data exceeds a specified threshold.

    Packets are delivered to the sink in the order they were written. The
    source is resumed once the queued amount drops back to the threshold,
    unless the pipe itself has been paused with `pause()`.
    """

    def __init__(
        self,
        pause_source,
        resume_source,
        write_to_sink=None,
        drain_sink=None,
        threshold=0,
    ):
        self.pause_source = pause_source
        self.resume_source = resume_source
        self.write_to_sink = write_to_sink
        # Optional coroutine function awaited after each packet is written.
        self.drain_sink = drain_sink
        # Queued byte count above which the source is paused.
        self.threshold = threshold
        self.queue = collections.deque()  # Queue of packets
        self.queued_bytes = 0  # Number of bytes in the queue
        # Set whenever `pump()` may be able to forward a packet.
        self.ready_to_pump = asyncio.Event()
        self.paused = False  # True while the pipe itself is paused via `pause()`
        self.source_paused = False  # True while `pause_source` is in effect
        self.pump_task = None

    def start(self):
        if self.pump_task is None:
            self.pump_task = asyncio.create_task(self.pump())

        self.check_pump()

    def stop(self):
        if self.pump_task is not None:
            self.pump_task.cancel()
            self.pump_task = None

    def write(self, packet):
        self.queued_bytes += len(packet)
        self.queue.append(packet)

        # Pause the source if we're over the threshold
        if self.queued_bytes > self.threshold and not self.source_paused:
            logger.debug(f'pausing source (queued={self.queued_bytes})')
            self.pause_source()
            self.source_paused = True

        self.check_pump()

    def pause(self):
        if not self.paused:
            self.paused = True
            if not self.source_paused:
                self.pause_source()
                self.source_paused = True
            self.check_pump()

    def resume(self):
        if self.paused:
            self.paused = False
            if self.source_paused:
                self.resume_source()
                self.source_paused = False
            self.check_pump()

    def can_pump(self):
        return self.queue and not self.paused and self.write_to_sink is not None

    def check_pump(self):
        if self.can_pump():
            self.ready_to_pump.set()
        else:
            self.ready_to_pump.clear()

    async def pump(self):
        while True:
            # Wait until we can try to pump packets
            await self.ready_to_pump.wait()

            # Try to pump a packet
            if self.can_pump():
                # `write()` appends on the right, so take from the left to
                # deliver packets in FIFO order.
                packet = self.queue.popleft()
                self.write_to_sink(packet)
                self.queued_bytes -= len(packet)

                # Drain the sink if we can
                if self.drain_sink:
                    await self.drain_sink()

                # Check if we can accept more. `pause()` may have been called
                # while draining; in that case the source stays paused until
                # `resume()` is called.
                if (
                    self.queued_bytes <= self.threshold
                    and self.source_paused
                    and not self.paused
                ):
                    logger.debug(f'resuming source (queued={self.queued_bytes})')
                    self.source_paused = False
                    self.resume_source()

            self.check_pump()
416
417
418# -----------------------------------------------------------------------------
async def async_call(function, *args, **kwargs):
    """
    Call `function(*args, **kwargs)` from inside a coroutine and return its result.

    The call is made as soon as the returned coroutine starts running; there is
    no suspension point before or after it. The async wrapper exists because
    Rust's `pyo3_asyncio` library needs functions to be marked async to
    properly inject a running loop.

    result = await async_call(some_function, ...)
    """
    result = function(*args, **kwargs)
    return result
429
430
431# -----------------------------------------------------------------------------
def wrap_async(function):
    """
    Return an async callable that forwards its arguments to `function`.

    The result is a `functools.partial` of `async_call` with `function`
    pre-bound. Calling it yields a coroutine that invokes `function` with the
    given arguments once it runs.
    """
    bound = functools.partial(async_call, function)
    return bound
437
438
439# -----------------------------------------------------------------------------
def deprecated(msg: str):
    """
    Decorator that emits a `DeprecationWarning` with `msg` each time the
    decorated function is called, then calls it and returns its result.

    The warning is attributed to the caller of the decorated function
    (`stacklevel=2`), so it points at the code that needs updating rather than
    at this wrapper.
    """

    def wrapper(function):
        @functools.wraps(function)
        def inner(*args, **kwargs):
            warnings.warn(msg, DeprecationWarning, stacklevel=2)
            return function(*args, **kwargs)

        return inner

    return wrapper
454
455
456# -----------------------------------------------------------------------------
def experimental(msg: str):
    """
    Decorator that emits a `FutureWarning` with `msg` each time the decorated
    function is called, then calls it and returns its result.

    The warning is attributed to the caller of the decorated function
    (`stacklevel=2`), so it points at the code using the experimental API
    rather than at this wrapper.
    """

    def wrapper(function):
        @functools.wraps(function)
        def inner(*args, **kwargs):
            warnings.warn(msg, FutureWarning, stacklevel=2)
            return function(*args, **kwargs)

        return inner

    return wrapper
471
472
473# -----------------------------------------------------------------------------
class OpenIntEnum(enum.IntEnum):
    """
    IntEnum variant that also accepts integers with no declared member.

    Looking up an undeclared int value produces an ad-hoc instance of the
    enum class whose value is that int and whose name is `ClassName[value]`.
    Non-int values are still rejected with the usual `ValueError`. Useful for
    protocol constants whose set of values grows over time.
    """

    @classmethod
    def _missing_(cls, value):
        # Only integers can be represented; anything else falls back to the
        # standard enum lookup failure.
        if isinstance(value, int):
            pseudo_member = int.__new__(cls, value)
            pseudo_member._value_ = value
            pseudo_member._name_ = f"{cls.__name__}[{value}]"
            return pseudo_member
        return None
490