1from __future__ import annotations
2
3import dataclasses
4import os
5from queue import Queue
6from typing import Protocol
7
8from watchdog.events import FileSystemEvent
9from watchdog.observers.api import EventEmitter, ObservedWatch
10from watchdog.utils import platform
11
# Platform-specific emitter class used by `Helper.start_watching`. The bare
# annotation declares the name for type checkers; the chain below binds it.
# NOTE: there is no `else` branch, so on a platform that matches none of these
# checks `Emitter` stays unbound and any use of it raises NameError.
Emitter: type[EventEmitter]

if platform.is_linux():
    from watchdog.observers.inotify import InotifyEmitter as Emitter
    # Imported only on Linux; `Helper.start_watching` guards its use with
    # `platform.is_linux()` so other platforms never touch this name.
    from watchdog.observers.inotify import InotifyFullEmitter
elif platform.is_darwin():
    from watchdog.observers.fsevents import FSEventsEmitter as Emitter
elif platform.is_windows():
    from watchdog.observers.read_directory_changes import WindowsApiEmitter as Emitter
elif platform.is_bsd():
    from watchdog.observers.kqueue import KqueueEmitter as Emitter
23
24
25class P(Protocol):
26    def __call__(self, *args: str) -> str: ...
27
28
29class StartWatching(Protocol):
30    def __call__(
31        self,
32        *,
33        path: bytes | str | None = ...,
34        use_full_emitter: bool = ...,
35        recursive: bool = ...,
36    ) -> EventEmitter: ...
37
38
39class ExpectEvent(Protocol):
40    def __call__(self, expected_event: FileSystemEvent, *, timeout: float = ...) -> None: ...
41
42
# Queue of (event, watch) pairs. `Helper` hands its queue to every emitter it
# starts, and `Helper.expect_event` reads the event half of each item.
TestEventQueue = Queue[
    tuple[FileSystemEvent, ObservedWatch]
]
44
45
@dataclasses.dataclass()
class Helper:
    """Per-test state: a temporary directory, the emitters watching it, and their event queue.

    Call `close()` at teardown to stop every emitter started via `start_watching()`.
    """

    tmp: str
    # Emitters started by `start_watching`; stopped and cleared by `close`.
    emitters: list[EventEmitter] = dataclasses.field(default_factory=list)
    # Single queue shared by every emitter this helper starts; read by `expect_event`.
    event_queue: TestEventQueue = dataclasses.field(default_factory=Queue)

    def joinpath(self, *args: str) -> str:
        """Return `args` joined onto `self.tmp` with `os.path.join`."""
        return os.path.join(self.tmp, *args)

    def start_watching(
        self,
        *,
        path: bytes | str | None = None,
        use_full_emitter: bool = False,
        recursive: bool = True,
    ) -> EventEmitter:
        """Create, register and start an emitter that watches `path`.

        Args:
            path: Directory to watch; defaults to `self.tmp`.
            use_full_emitter: Use `InotifyFullEmitter` instead of the platform
                default. Honored only on Linux and silently ignored elsewhere.
            recursive: Whether the watch is recursive.

        Returns:
            The started emitter. It is also recorded in `self.emitters` so that
            `close()` stops it.
        """
        # TODO: check if other platforms expect the trailing slash (e.g. `p('')`)
        path = self.tmp if path is None else path

        watcher = ObservedWatch(path, recursive=recursive)
        emitter_cls = InotifyFullEmitter if platform.is_linux() and use_full_emitter else Emitter
        emitter = emitter_cls(self.event_queue, watcher)

        if platform.is_darwin():
            # TODO: I think this could be better...  .suppress_history should maybe
            #       become a common attribute.
            # Presumably stops FSEvents from replaying events that predate the
            # watch — confirm against FSEventsEmitter.
            from watchdog.observers.fsevents import FSEventsEmitter

            assert isinstance(emitter, FSEventsEmitter)
            emitter.suppress_history = True

        # Register before starting so `close()` still sees the emitter even if
        # `start()` raises.
        self.emitters.append(emitter)
        emitter.start()

        return emitter

    def expect_event(self, expected_event: FileSystemEvent, timeout: float = 2) -> None:
        """Assert that the next queued event equals `expected_event`.

        Blocks for up to `timeout` seconds waiting for the next queue item,
        which gives some robustness against the asynchronous nature of
        file-system notifications. Only the *next* item is checked, and only
        its event half; the watch half is ignored.

        Raises:
            queue.Empty: if nothing arrives within `timeout` seconds.
            AssertionError: if the next event differs from `expected_event`.
        """
        event = self.event_queue.get(timeout=timeout)[0]
        # This module is not a test file, so pytest does not rewrite its
        # asserts; spell out the mismatch explicitly.
        assert event == expected_event, f"expected {expected_event!r}, got {event!r}"

    def close(self) -> None:
        """Stop and join every registered emitter, then forget all of them.

        Raises:
            AssertionError: if any emitter is still alive after being joined.
        """
        # Signal every emitter before joining any, so they can wind down in
        # parallel rather than one after another.
        for emitter in self.emitters:
            emitter.stop()

        for emitter in self.emitters:
            # Skip emitters that are not running (e.g. one whose start() failed);
            # presumably EventEmitter is a threading.Thread, whose join() raises
            # if the thread was never started.
            if emitter.is_alive():
                emitter.join(5)

        still_alive = [emitter for emitter in self.emitters if emitter.is_alive()]
        # Clear before asserting so a later close() does not retry these emitters.
        self.emitters = []
        assert not still_alive, f"emitters still alive after close(): {still_alive!r}"
100