xref: /aosp_15_r20/external/pigweed/pw_console/py/pw_console/web_kernel.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2024 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""REPL Kernel for web console."""
15
import ast
import asyncio
import io
import json
import logging
import sys
import types
from typing import Any

from aiohttp.web_ws import WebSocketResponse
from prompt_toolkit.completion import (
    CompleteEvent,
    merge_completers,
    Completion,
)
from prompt_toolkit.document import Document
from ptpython.completer import PythonCompleter, Completer
from ptpython.repl import _has_coroutine_flag

from pw_console.embed import create_word_completer
from pw_console.python_logging import log_record_to_dict
36
# Module-level logger, keyed by this module's package name.
_LOG = logging.getLogger(name=__package__)
38
39
class UnknownRequestType(Exception):
    """Raised when a request's ``type`` field is missing or not understood.

    ``WebKernel.handle_request`` raises it for requests without a ``type``.
    """
42
43
class UnknownRequestData(Exception):
    """Raised when a request arrives without its ``data`` payload.

    ``WebKernel.handle_request`` raises it for requests whose ``data`` is
    absent or null.
    """
46
47
class MissingCallId(Exception):
    """Raised when a request has no ``id`` to tag its response with.

    ``WebKernel.handle_request`` raises it for requests without an ``id``.
    """
50
51
def process_partial_expressions(
    suggestions: list[Completion],
) -> list[Completion]:
    """Trim completions whose text is a full dotted expression.

    Some completers (such as the sentence WordCompleter) return the whole
    expression instead of just the attribute being completed. For the input
    ``'device.rp'``, a suggestion ``'device.rpcs.blinky.Blinky.Blink'`` that
    replaces the 9 typed characters (``start_position=-9``) is cut after the
    last ``.`` found within that replaced span, giving
    ``'rpcs.blinky.Blinky.Blink'``.

    Suggestions that do not replace text before the cursor
    (``start_position >= 0``), or whose replaced span has no ``.``, keep
    their full text. Every returned Completion has ``start_position`` 0 and
    keeps the original ``display``.
    """

    def _trimmed_text(suggestion: Completion) -> str:
        text = suggestion.text
        replaced_len = -suggestion.start_position
        if replaced_len <= 0:
            # Nothing before the cursor is replaced: keep the full text.
            return text
        dot_index = text.rfind('.', 0, replaced_len)
        if dot_index == -1:
            return text
        # Completing a nested property: keep what follows the last dot.
        return text[dot_index + 1 :]

    return [
        Completion(_trimmed_text(suggestion), 0, suggestion.display)
        for suggestion in suggestions
    ]
87
88
def format_completions(
    all_completions: list[Completion],
) -> list[dict[str, str]]:
    """Convert prompt_toolkit completions into JSON-serializable dicts.

    Completions whose text starts with ``_`` (private names) are dropped. The
    rest are trimmed by ``process_partial_expressions``.

    Args:
      all_completions: Completions produced by the kernel's completer.

    Returns:
      A list of ``{'text': ..., 'type': ...}`` dicts. ``type`` is
      ``'keyword'`` when the completion's display text ends with ``()``
      (presumably how the Python completer marks callables; confirm for
      other completers) and ``'variable'`` otherwise.
    """
    public_completions = process_partial_expressions(
        # Hide private suggestions
        [
            completion
            for completion in all_completions
            if not completion.text.startswith('_')
        ]
    )

    return [
        {
            'text': completion.text,
            # display_text joins every formatted-text fragment. Looking only
            # at display[0][1] misses a '()' suffix held in a later fragment
            # and raises IndexError when the display is empty.
            'type': (
                'keyword'
                if completion.display_text.endswith('()')
                else 'variable'
            ),
        }
        for completion in public_completions
    ]
112
113
def _format_result_output(result) -> str:
    """Return a plaintext repr of any object.

    If the object's ``__repr__`` raises an ordinary exception, an empty
    string is returned so a response can still be built. Control-flow
    exceptions such as ``KeyboardInterrupt``, ``SystemExit`` and
    ``asyncio.CancelledError`` are not swallowed.
    """
    try:
        return repr(result)
    except Exception:  # pylint: disable=broad-exception-caught
        # A broken __repr__ must not break the response; fall back to ''.
        return ''
122
123
def compile_code(code: str, mode: str) -> types.CodeType:
    """Compile REPL input ``code`` in ``mode`` (e.g. ``'eval'`` or ``'exec'``).

    The code object reports ``<stdin>`` as its filename, and compiler flags
    active in this module are not inherited.
    """
    source_name = '<stdin>'
    return compile(code, source_name, mode, dont_inherit=True)
131
132
class WebSocketStreamingResponder(logging.Handler):
    """Python logging handler that sends json serialized logs.

    Each handled record is serialized once and sent as a streaming message
    to every request id in ``request_ids``.

    Args:
      connection: the WebSocketResponse object to send json logs to.
      loop: The asyncio loop to run the send_str in.
    """

    def __init__(
        self,
        connection: WebSocketResponse,
        loop: asyncio.AbstractEventLoop,
        *args,
        **kwargs,
    ) -> None:
        self.connection = connection
        self.loop = loop
        # Ids of the requests subscribed to this handler; maintained by the
        # owning kernel.
        self.request_ids: list[int] = []
        super().__init__(*args, **kwargs)

    def emit(self, record: logging.LogRecord) -> None:
        """Schedule sending ``record`` on ``self.loop``.

        ``asyncio.run_coroutine_threadsafe`` is used, so this may be called
        from threads other than the loop's. Per the ``logging.Handler``
        contract, a failure to schedule (for example, when the loop is
        already closed) goes through ``handleError`` instead of raising into
        the code that logged.
        """
        coroutine = self._process_log(record)
        try:
            asyncio.run_coroutine_threadsafe(coroutine, self.loop)
        except Exception:  # pylint: disable=broad-exception-caught
            # The coroutine will never run; close it so Python does not warn
            # that it was never awaited.
            coroutine.close()
            self.handleError(record)

    async def _process_log(self, record: logging.LogRecord) -> None:
        """Send the log serialized to json via the current WebSocketResponse."""
        # Iterate over a snapshot: ids may be added or removed while this
        # coroutine is suspended in send_str(), and mutating a list during
        # iteration can skip still-subscribed ids.
        request_ids = list(self.request_ids)
        if not request_ids:
            return
        # Serialize once; the same log line is sent to every subscriber.
        log_line = log_record_to_dict(record)
        for req_id in request_ids:
            if req_id not in self.request_ids:
                # Unsubscribed while an earlier send was in progress.
                continue
            await self.connection.send_str(
                json.dumps(
                    {
                        'id': req_id,
                        'streaming': True,
                        'data': {'log_line': log_line},
                    }
                )
            )
172
173
class WebKernel:
    """Web Kernel implementation.

    A kernel serves one websocket connection. It evaluates Python code sent
    by the browser, answers autocompletion requests, and streams records
    from the configured log sources back over the connection.
    """

    def __init__(
        self,
        connection: WebSocketResponse,
        kernel_params: dict[str, Any],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Create a new kernel for this particular websocket connection.

        Args:
          connection: Websocket that streamed log messages are sent to.
          kernel_params: Kernel configuration, mutated in place. Keys read
            by this class:
            - ``global_vars`` / ``local_vars``: eval namespaces, replaced with
              empty dicts when missing or None.
            - ``sentence_completions``: extra word completions.
            - ``loggers``: log source name mapped to the loggers of that
              source.
          loop: Event loop on which log messages are sent.
        """
        self.connection = connection
        self.kernel_params = kernel_params
        # Make sure global and local vars are not set to None.
        if kernel_params.get('global_vars', None) is None:
            self.kernel_params['global_vars'] = {}
        if kernel_params.get('local_vars', None) is None:
            self.kernel_params['local_vars'] = {}
        self.loop = loop

        # One streaming handler per subscribed log source name.
        self.logger_handlers: dict[str, WebSocketStreamingResponder] = {}
        self.connected = False
        python_completer = PythonCompleter(
            self.get_globals,
            self.get_locals,
            # NOTE(review): presumably ptpython's enable-dictionary-completion
            # callback; confirm against the ptpython version in use.
            lambda: True,
        )
        all_completers: list[Completer] = [python_completer]

        if kernel_params.get('sentence_completions'):
            word_completer = create_word_completer(
                kernel_params.get('sentence_completions', {})
            )
            all_completers.append(word_completer)

        # Merge default Python completer with the new custom one.
        self.completer = merge_completers(all_completers)

    async def handle_request(self, request) -> str:
        """Handle the request from web browser.

        ``request`` is a JSON string carrying ``type``, ``data`` and ``id``
        fields.

        Returns:
          The JSON-encoded response; ``'unknown'`` for an unrecognized
          request type; or ``''`` (after logging) when a ValueError such as
          malformed JSON occurs.

        Raises:
          UnknownRequestType: the request has no ``type``.
          UnknownRequestData: the request has no ``data``.
          MissingCallId: the request has no ``id``.
          KeyError: ``data`` lacks an attribute the request type needs.
        """
        try:
            parsed_request = json.loads(request)
            request_type = parsed_request.get('type', None)
            request_data = parsed_request.get('data', None)
            call_id = parsed_request.get('id', None)
            if request_type is None:
                raise UnknownRequestType(
                    'Unknown request type: {}'.format(parsed_request)
                )
            if request_data is None:
                raise UnknownRequestData(
                    'Unknown request data: {}'.format(parsed_request)
                )

            if call_id is None:
                raise MissingCallId(
                    'Missing call id: {}'.format(parsed_request)
                )

            if request_type == 'autocomplete':
                try:
                    completions = self.handle_autocompletion(
                        request_data['code'],
                        request_data['cursor_pos'],
                    )
                    return json.dumps({'id': call_id, 'data': completions})
                except KeyError as error:
                    raise KeyError(
                        (
                            'Missing data.code or data.cursor_pos attributes:'
                            '{}'.format(request_data)
                        )
                    ) from error

            if request_type == 'eval':
                try:
                    result = await self.handle_eval(request_data['code'])
                    return json.dumps({'id': call_id, 'data': result})
                except KeyError as error:
                    raise KeyError(
                        'Missing data.code attributes: {}'.format(request_data)
                    ) from error

            if request_type == 'log_source_list':
                sources = self.handle_log_source_list()
                return json.dumps({'id': call_id, 'data': sources})
            if request_type == 'log_source_subscribe':
                try:
                    # Close requests have same call id with just .close = true
                    if 'close' in parsed_request:
                        has_unsubbed = self.handle_log_source_unsubscribe(
                            request_data['name'], call_id
                        )
                        return json.dumps(
                            {
                                'id': call_id,
                                'streaming': True,
                                'data': has_unsubbed,
                            }
                        )
                    has_subscribed = self.handle_log_source_subscribe(
                        request_data['name'], call_id
                    )
                    return json.dumps(
                        {
                            'id': call_id,
                            'streaming': True,
                            'data': has_subscribed,
                        }
                    )
                except KeyError as error:
                    raise KeyError(
                        'Missing data.name attributes: {}'.format(request_data)
                    ) from error

            return 'unknown'
        except ValueError:
            _LOG.error('Failed to parse request: %s', request)
            return ''

    async def handle_eval(self, code: str) -> dict[str, str] | None:
        """Evaluate user code and return its result and captured output.

        Returns:
          A dict with ``result`` (repr of the value, or of the raised
          exception; ``''`` for None), ``stdout`` and ``stderr`` (text
          written while evaluating), or None if the code raised SystemExit.

        Raises:
          KeyboardInterrupt: re-raised after the real streams are restored.
        """
        # Patch stdout and stderr to capture repl print() statements.
        # NOTE(review): sys.stdout/sys.stderr are process-global, so output
        # from other coroutines that run while an eval is awaiting is
        # captured as well.
        temp_stdout = io.StringIO()
        temp_stderr = io.StringIO()
        original_stdout = sys.stdout
        original_stderr = sys.stderr

        sys.stdout = temp_stdout
        sys.stderr = temp_stderr

        try:
            result = await self._eval_async(code)
        except KeyboardInterrupt:
            raise
        except SystemExit:
            return None
        except BaseException as e:  # pylint: disable=broad-exception-caught
            # Report the exception itself as the result; it is repr'd once
            # below.
            result = e
        finally:
            # Always restore the real streams, including on SystemExit and
            # KeyboardInterrupt.
            sys.stdout = original_stdout
            sys.stderr = original_stderr

        return {
            'result': (
                _format_result_output(result) if result is not None else ''
            ),
            'stdout': temp_stdout.getvalue(),
            'stderr': temp_stderr.getvalue(),
        }

    @staticmethod
    def _compile(code: str, mode: str) -> types.CodeType:
        """Compile REPL input in ``mode``, allowing top-level ``await``.

        Without ``ast.PyCF_ALLOW_TOP_LEVEL_AWAIT``, a top-level ``await`` is
        a SyntaxError, and the code object never carries the coroutine flag
        that ``_eval_async`` checks for.
        """
        return compile(
            code,
            '<stdin>',
            mode,
            flags=ast.PyCF_ALLOW_TOP_LEVEL_AWAIT,
            dont_inherit=True,
        )

    async def _eval_async(self, code: str) -> Any:
        """Evaluate the code and return its result.

        ``code`` is first compiled as a single expression, whose value is
        returned. If that fails with SyntaxError, it is compiled and run as
        statements. Code using top-level ``await`` evaluates to a coroutine,
        which is awaited here.
        """

        # WORKAROUND: Due to a bug in Jedi, the current directory is removed
        # from sys.path. See: https://github.com/davidhalter/jedi/issues/1148
        if '' not in sys.path:
            sys.path.insert(0, '')

        # Try eval first
        try:
            compiled_code = self._compile(code, 'eval')
        except SyntaxError:
            pass
        else:
            # No syntax errors for eval. Evaluate the compiled expression
            # (not the source string, which would lose the await flag).
            result = eval(  # pylint: disable=eval-used
                compiled_code, self.get_globals(), self.get_locals()
            )

            if _has_coroutine_flag(compiled_code):
                result = await result

            return result

        # If not a valid `eval` expression, compile as `exec` expression
        compiled_code = self._compile(code, 'exec')
        result = eval(  # pylint: disable=eval-used
            compiled_code, self.get_globals(), self.get_locals()
        )

        if _has_coroutine_flag(compiled_code):
            return await result

        return None

    def handle_autocompletion(
        self, code: str, cursor_pos: int
    ) -> list[dict[str, str]]:
        """Return formatted completions for ``code`` at ``cursor_pos``."""
        doc = Document(code, cursor_pos)
        all_completions = list(
            self.completer.get_completions(
                doc,
                CompleteEvent(completion_requested=False, text_inserted=True),
            )
        )
        return format_completions(all_completions)

    def handle_disconnect(self) -> None:
        """Mark the kernel disconnected and detach all of its log handlers."""
        _LOG.info('pw_console.web_kernel disconnecting.')
        self.connected = False
        # Clean up all log handlers as we are shutting down. Only subscribed
        # log sources have a handler, so iterate those rather than every
        # configured logger.
        loggers = self.kernel_params.get('loggers') or {}
        for logger_name, handler in self.logger_handlers.items():
            for logger in loggers.get(logger_name, []):
                logger.removeHandler(handler)
        self.logger_handlers.clear()

    def get_globals(self) -> dict[str, Any]:
        """Return the global namespace used for evaluation."""
        return self.kernel_params.get('global_vars', globals())

    def get_locals(self) -> dict[str, Any]:
        """Return the local namespace used for evaluation."""
        return self.kernel_params.get('local_vars', self.get_globals())

    def handle_log_source_list(self) -> list[str]:
        """Return the names of the configured log sources."""
        if 'loggers' in self.kernel_params:
            return list(self.kernel_params['loggers'].keys())
        return []

    def handle_log_source_subscribe(self, logger_name, request_id) -> bool:
        """Stream records of log source ``logger_name`` to ``request_id``.

        All request ids subscribed to a source share one handler. The handler
        is attached to the source's loggers on the first subscription.

        Returns:
          True if subscribed; False if the source is unknown or has no
          loggers.
        """
        loggers = (self.kernel_params.get('loggers') or {}).get(logger_name)
        if not loggers:
            return False
        handler = self.logger_handlers.get(logger_name)
        if handler is None:
            handler = WebSocketStreamingResponder(
                self.connection,
                self.loop,
            )
            self.logger_handlers[logger_name] = handler
            for logger in loggers:
                logger.addHandler(handler)
        handler.request_ids.append(request_id)
        return True

    def handle_log_source_unsubscribe(self, logger_name, request_id) -> bool:
        """Stop streaming log source ``logger_name`` to ``request_id``.

        The shared handler is detached from the source's loggers once its
        last request id unsubscribes.

        Returns:
          True if ``request_id`` was subscribed and has been removed;
          False otherwise.
        """
        loggers = (self.kernel_params.get('loggers') or {}).get(logger_name)
        handler = self.logger_handlers.get(logger_name)
        if not loggers or handler is None:
            return False
        if request_id not in handler.request_ids:
            return False
        handler.request_ids.remove(request_id)
        # Remove handler if all requests have unsubscribed
        if not handler.request_ids:
            for logger in loggers:
                logger.removeHandler(handler)
            del self.logger_handlers[logger_name]
        return True
425