xref: /aosp_15_r20/external/pigweed/pw_docgen/py/pw_docgen/docserver.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2023 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Serve locally-built docs files.
15
16There are essentially four components here:
17
181. A simple HTTP server that serves docs out of the build directory.
19
202. JavaScript that tells a doc to refresh itself when it receives a message
21   that its source file has been changed. This is injected into each served
22   page by #1.
23
243. A WebSocket server that pushes refresh messages to pages generated by #1
25   using the WebSocket client included in #2.
26
274. A very simple file watcher that looks for changes in the built docs files
28   and pushes messages about changed files to #3.
29"""
30
import asyncio
import http.server
import io
import json
import logging
from pathlib import Path
import socketserver
import threading
from tempfile import TemporaryFile
from typing import Callable, Optional

from watchdog.events import FileModifiedEvent, FileSystemEventHandler
from watchdog.observers import Observer
import websockets
44
45_LOG = logging.getLogger('pw_docgen.docserver')
46
47
48def _generate_script(path: str, host: str, port: str) -> bytes:
49    """Generate the JavaScript to inject into served docs pages."""
50    return f"""<script>
51    var connection = null;
52    var originFilePath = "{path}";
53
54    function watchForReload() {{
55        connection = new WebSocket("ws://{host}:{port}/");
56        console.log("Connecting to WebSocket server...");
57
58        connection.onopen = function () {{
59            console.log("Connected to WebSocket server");
60        }}
61
62        connection.onerror = function () {{
63            console.log("WebSocket connection disconnected or failed");
64        }}
65
66        connection.onmessage = function (message) {{
67            if (message.data === originFilePath) {{
68                window.location.reload(true);
69            }}
70        }}
71    }}
72
73    watchForReload();
74</script>
75</body>
76""".encode(
77        "utf-8"
78    )
79
80
class OpenAndInjectScript:
    """A substitute for `open` that injects the refresh handler script.

    Instead of returning a handle to the file you asked for, it returns a
    handle to a temporary file holding a modified copy. That file will
    disappear as soon as it is `.close()`ed, but that has to be done manually;
    it will not close automatically when exiting scope.

    The instance stores the last path that was opened in `path`.
    """

    def __init__(self, host: str, port: str):
        """Remember where the WebSocket server lives.

        Args:
            host: Host of the WebSocket server the injected script connects to.
            port: Port of that WebSocket server.
        """
        self.path: str = ""
        self._host = host
        self._port = port

    def __call__(self, path: str, mode: str) -> io.BufferedReader:
        """Open a copy of `path` with the refresh script injected.

        Args:
            path: File to read.
            mode: Mode the caller asked for; must be a binary mode.

        Returns:
            A temporary binary file positioned at offset 0. Every `</body>`
            in the original content is replaced by the refresh script (which
            itself ends in `</body>`); content without that tag is copied
            unchanged.

        Raises:
            ValueError: If `mode` is not a binary mode.
        """
        if 'b' not in mode:
            raise ValueError(
                "This should only be used to open files in binary mode."
            )

        content = Path(path).read_bytes()
        # Only build the script when there is a closing tag to attach it to.
        if b"</body>" in content:
            content = content.replace(
                b"</body>", _generate_script(path, self._host, self._port)
            )

        injected_file = TemporaryFile('w+b')
        try:
            injected_file.write(content)
            # Let the caller read the file like it's just been opened.
            injected_file.seek(0)
        except BaseException:
            # Nobody else holds this handle yet, so release it before
            # propagating the failure.
            injected_file.close()
            raise

        # Store the path that held the original file.
        self.path = path
        return injected_file  # type: ignore
116
117
118def _docs_http_server(
119    address: str, port: int, path: Path
120) -> Callable[[], None]:
121    """A simple file system-based HTTP server for built docs."""
122
123    class DocsStaticRequestHandler(http.server.SimpleHTTPRequestHandler):
124        def __init__(self, *args, **kwargs):
125            super().__init__(*args, directory=str(path), **kwargs)
126
127        # Disable logs to stdout.
128        def log_message(
129            self, format: str, *args  # pylint: disable=redefined-builtin
130        ) -> None:
131            return
132
133    def http_server_thread():
134        with socketserver.TCPServer(
135            (address, port), DocsStaticRequestHandler
136        ) as httpd:
137            httpd.serve_forever()
138
139    return http_server_thread
140
141
class TaskFinishedException(Exception):
    """Flow-control signal: one awaited task ran to completion.

    This is not an error. It is raised by a task that finished its work so
    that whoever awaits it knows to start another round.
    """
144
145
class WebSocketConnectionClosedException(Exception):
    """Flow-control signal: the client's WebSocket connection is closed."""
148
149
class DocsWebsocketRequestHandler:
    """WebSocket server that sends page refresh info to clients.

    Push messages to the message queue to broadcast them to all connected
    clients. `push_message` may be called from any thread; the coroutines run
    on the event loop created by `run`.
    """

    def __init__(self, address: str = '127.0.0.1', port: int = 8765):
        """Store the listen address; nothing is bound until `run` is called.

        Args:
            address: Interface address the WebSocket server listens on.
            port: TCP port the WebSocket server listens on.
        """
        self._address = address
        self._port = port
        self._connections = set()  # type: ignore
        self._messages: asyncio.Queue = asyncio.Queue()
        # Event loop that owns `_messages`. Stays None until `_run` starts.
        self._loop: Optional[asyncio.AbstractEventLoop] = None

    async def _register_connection(self, websocket) -> None:
        """Handle one client connection until it is found closed.

        Each round races "send the next queued message" against a one second
        heartbeat check. Whichever finishes first raises a flow-control
        exception that decides whether to start another round or stop.
        """
        self._connections.add(websocket)
        _LOG.info("Client connection established: %s", websocket.id)

        try:
            while True:
                tasks = [
                    asyncio.ensure_future(self._send_messages()),
                    asyncio.ensure_future(
                        self._drop_lost_connection(websocket)
                    ),
                ]
                try:
                    # Run both tasks simultaneously. We don't wait for *all*
                    # of them to finish -- when one finishes, it raises one
                    # of the flow control exceptions to determine what
                    # happens next.
                    await asyncio.gather(*tasks)
                except TaskFinishedException:
                    _LOG.debug(
                        "One awaited task finished; iterating event loop."
                    )
                except WebSocketConnectionClosedException:
                    _LOG.debug(
                        "WebSocket connection closed; ending event loop."
                    )
                    return
                finally:
                    # `gather` leaves the other task running when one raises.
                    # Cancel it so that a queue waiter is not orphaned on
                    # every heartbeat. Cancelling a finished task is a no-op.
                    for task in tasks:
                        task.cancel()
        finally:
            # Whatever ends this handler, stop broadcasting to its socket.
            self._connections.discard(websocket)

    async def _drop_lost_connection(self, websocket) -> None:
        """Wait one second, then report whether the client is still there.

        Raises:
            WebSocketConnectionClosedException: `websocket.closed` is true;
                the connection has been dropped from the broadcast set.
            TaskFinishedException: The connection is still open.
        """
        await asyncio.sleep(1)

        if websocket.closed:
            # `discard` rather than `remove`: the connection may already have
            # been dropped by an earlier check whose signal lost the race
            # against a message send.
            self._connections.discard(websocket)
            _LOG.info("Client connection dropped: %s", websocket.id)
            raise WebSocketConnectionClosedException

        _LOG.debug("Client connection heartbeat active: %s", websocket.id)
        raise TaskFinishedException

    async def _send_messages(self) -> None:
        """Send the next message in the message queue to all clients.

        Every page change is broadcast to every client. It is up to the client
        to determine whether the contents of a message mean it should refresh.
        This is a pretty easy determination to make though -- the client knows
        its own source file's path, so it just needs to check if the path in
        the message matches it.

        Raises:
            TaskFinishedException: Always, once one message has been sent.
        """
        message = await self._messages.get()
        websockets.broadcast(self._connections, message)  # type: ignore # pylint: disable=no-member
        _LOG.info("Sent to %d clients: %s", len(self._connections), message)
        raise TaskFinishedException

    async def _run(self) -> None:
        """Serve WebSocket clients until the surrounding task is cancelled."""
        # Record the loop and (re)create the queue from inside it, so that
        # both belong to the loop that actually consumes the messages.
        self._loop = asyncio.get_running_loop()
        self._messages = asyncio.Queue()

        async with websockets.serve(  # type: ignore # pylint: disable=no-member
            self._register_connection, self._address, self._port
        ):
            # A future nobody resolves: keeps the server up indefinitely.
            await asyncio.Future()

    def push_message(self, message: str) -> None:
        """Push a message on to the message queue.

        Messages are dropped when no client is connected. Safe to call from a
        thread other than the one running the server.
        """
        if not self._connections:
            return

        loop = self._loop
        if loop is None:
            # The server loop has not started, so nothing is waiting on the
            # queue yet and a direct put cannot race with a consumer.
            self._messages.put_nowait(message)
        else:
            # asyncio queues are not thread-safe; hand the put to the loop
            # that owns the queue, which also wakes a waiting consumer.
            loop.call_soon_threadsafe(self._messages.put_nowait, message)
        _LOG.info("Pushed to message queue: %s", message)

    def run(self):
        """Run the WebSocket server."""
        asyncio.run(self._run())
226
227
class DocsFileChangeEventHandler(FileSystemEventHandler):
    """Handle watched built doc files events."""

    def __init__(self, ws_handler: DocsWebsocketRequestHandler) -> None:
        """Remember the WebSocket server that change messages are pushed to."""
        super().__init__()
        self._ws_handler = ws_handler

    def on_modified(self, event) -> None:
        """Push the path of a modified file to the WebSocket server.

        Only `FileModifiedEvent`s produce a message. The path is sent relative
        to the current working directory when it lies inside it, and exactly
        as reported otherwise.
        """
        if isinstance(event, FileModifiedEvent):
            path = Path(event.src_path)
            try:
                path = path.relative_to(Path.cwd())
            except ValueError:
                # `relative_to` rejects paths that are already relative or
                # that live outside the working directory. Send those as
                # reported instead of letting the error escape into the
                # watcher.
                pass
            # Push the path of the modified file to the WebSocket server's
            # message queue.
            self._ws_handler.push_message(str(path))

        return super().on_modified(event)
242
243
class DocsFileChangeObserver(Observer):  # pylint: disable=too-many-ancestors
    """Observer that is pre-scheduled to watch a built docs directory.

    Construction only registers the watch; call `start()` to begin observing.
    """

    def __init__(
        self, path: str, event_handler: FileSystemEventHandler, *args, **kwargs
    ):
        """Schedule `event_handler` for `path` and everything beneath it.

        Any extra positional or keyword arguments go to the base observer.
        """
        super().__init__(*args, **kwargs)
        self.schedule(event_handler, path=path, recursive=True)
        _LOG.info("Watching build docs files at: %s", path)
253
254
def serve_docs(
    build_dir: Path,
    docs_path: Path,
    address: str = '127.0.0.1',
    port: int = 8000,
    ws_port: int = 8765,
) -> None:
    """Run the docs server.

    Starts the WebSocket server thread, the HTTP server thread and the file
    watcher, in that order, then returns without waiting for any of them.

    Args:
        build_dir: Root of the build output.
        docs_path: Docs directory under `build_dir`; its `html` subdirectory
            is what gets served and watched.
        address: Address both servers bind to.
        port: Port of the HTTP server.
        ws_port: Port of the WebSocket server.
    """
    html_root = build_dir.joinpath(docs_path, 'html')
    run_http_server = _docs_http_server(address, port, html_root)

    # `http.server.SimpleHTTPRequestHandler.send_head` opens the requested
    # file, sends the headers, and hands the open file on to be written to the
    # client. The only hook for altering the content in between is replacing
    # the `open` that `http.server` looks up -- distasteful, but effective.
    setattr(http.server, 'open', OpenAndInjectScript(address, str(ws_port)))

    websocket_server = DocsWebsocketRequestHandler(address, ws_port)
    change_handler = DocsFileChangeEventHandler(websocket_server)

    ws_thread = threading.Thread(
        target=websocket_server.run, name='pw_docserver_ws'
    )
    ws_thread.start()

    http_thread = threading.Thread(
        target=run_http_server, name='pw_docserver_http'
    )
    http_thread.start()

    watcher = DocsFileChangeObserver(str(html_root), change_handler)
    watcher.start()

    _LOG.info('Serving docs at http://%s:%d', address, port)
286
287
async def ws_client(
    address: str = '127.0.0.1',
    port: int = 8765,
):
    """Connect to a WebSocket server and log every message it sends.

    Meant as a testing aid; run it with `asyncio.run(ws_client())`.

    Args:
        address: Address of the WebSocket server.
        port: Port of the WebSocket server.
    """
    server_uri = f"ws://{address}:{port}"
    async with websockets.connect(server_uri) as websocket:  # type: ignore # pylint: disable=no-member
        _LOG.info("Connection ID: %s", websocket.id)
        async for message in websocket:
            _LOG.info("Message received: %s", message)
300