# Copyright 2024 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""REPL Kernel for web console."""

import asyncio
import io
import json
import logging
import sys
import types
from typing import Any

from aiohttp.web_ws import WebSocketResponse
from prompt_toolkit.completion import (
    CompleteEvent,
    merge_completers,
    Completion,
)
from prompt_toolkit.document import Document
from ptpython.completer import PythonCompleter, Completer
from ptpython.repl import _has_coroutine_flag

from pw_console.embed import create_word_completer
from pw_console.python_logging import log_record_to_dict

_LOG = logging.getLogger(__package__)


class UnknownRequestType(Exception):
    """Exception for request with unknown or missing types."""


class UnknownRequestData(Exception):
    """Exception for request with missing data attributes."""


class MissingCallId(Exception):
    """Exception for request with missing call id."""


def process_partial_expressions(
    suggestions: list[Completion],
) -> list[Completion]:
    """Trim completions that are full expressions down to the missing part.

    Some completers return the whole dotted expression rather than only the
    text following the attribute being completed. For every suggestion with a
    negative ``start_position``, the text is cut just after the last ``.``
    found within the first ``-start_position`` characters.

    Example:
      If the input is 'device.rp' and a completer suggests
      'device.rpcs.blinky.Blinky.Blink' with start_position == -9, the
      returned completion text is 'rpcs.blinky.Blinky.Blink'.

    Args:
      suggestions: Completions as returned by the completer.

    Returns:
      New Completion objects, one per suggestion, each with a start position
      of 0 and the original display value.
    """
    processed_completions = []
    for suggestion in suggestions:
        trimmed_text = suggestion.text
        typed_length = -suggestion.start_position

        # A negative start position means part of the suggestion was already
        # typed; only look for the dot inside that already-typed prefix.
        if typed_length > 0:
            last_dot = trimmed_text.rfind('.', 0, typed_length)
            if last_dot != -1:
                # We are completing a nested property.
                trimmed_text = trimmed_text[last_dot + 1 :]

        processed_completions.append(
            Completion(trimmed_text, 0, suggestion.display)
        )

    return processed_completions


def format_completions(
    all_completions: list[Completion],
) -> list[dict[str, str]]:
    """Convert completions into JSON serializable dicts.

    Completions whose text starts with an underscore are dropped, the rest are
    trimmed with process_partial_expressions().

    Args:
      all_completions: Completions as returned by the completer.

    Returns:
      A list of dicts with a 'text' key and a 'type' key. The type is
      'keyword' when the text of the first display fragment ends with '()'
      and 'variable' otherwise.
    """
    # Hide private suggestions.
    public_completions = [
        completion
        for completion in all_completions
        if not completion.text.startswith('_')
    ]

    return [
        {
            'text': completion.text,
            'type': (
                'keyword'
                if completion.display[0][1].endswith('()')
                else 'variable'
            ),
        }
        for completion in process_partial_expressions(public_completions)
    ]


def _format_result_output(result) -> str:
    """Return a plaintext repr of any object.

    Returns an empty string if calling repr() on the object raises.
    """
    try:
        formatted_result = repr(result)
    except BaseException:  # pylint: disable=broad-exception-caught
        # A failing __repr__ must not break the response being built.
        formatted_result = ''
    return formatted_result


def compile_code(code: str, mode: str) -> types.CodeType:
    """Compile code with the '<stdin>' filename in the given compile() mode.

    Future statements active in this module are not inherited by the compiled
    code (dont_inherit=True).

    Raises:
      SyntaxError: If the code is not valid for the requested mode.
    """
    return compile(
        code,
        '<stdin>',
        mode,
        dont_inherit=True,
    )


class WebSocketStreamingResponder(logging.Handler):
    """Python logging handler that sends json serialized logs.

    Args:
        connection: the WebSocketResponse object to send json logs to.
        loop: The asyncio loop to run the send_str in.
    """

    def __init__(
        self,
        connection: WebSocketResponse,
        loop: asyncio.AbstractEventLoop,
        *args,
        **kwargs,
    ) -> None:
        self.connection = connection
        self.loop = loop
        # Ids of the requests that should receive every log record.
        self.request_ids: list[int] = []
        super().__init__(*args, **kwargs)

    def emit(self, record: logging.LogRecord) -> None:
        """Schedule sending of the record on the handler's event loop."""
        # emit() may be called from any thread, so hand the coroutine over to
        # self.loop instead of awaiting it here.
        asyncio.run_coroutine_threadsafe(
            self._process_log(record),
            self.loop,
        )

    async def _process_log(self, record: logging.LogRecord) -> None:
        """Send the log serialized to json via the current WebSocketResponse.

        One message is sent per subscribed request id.
        """
        # Iterate over a snapshot: request ids can be added or removed while
        # this coroutine is suspended in send_str().
        request_ids = tuple(self.request_ids)
        if not request_ids:
            return

        log_line = log_record_to_dict(record)
        for req_id in request_ids:
            await self.connection.send_str(
                json.dumps(
                    {
                        'id': req_id,
                        'streaming': True,
                        'data': {'log_line': log_line},
                    }
                )
            )


class WebKernel:
    """Web Kernel implementation.

    Handles JSON requests received over one websocket connection: code
    evaluation, autocompletion and log source subscriptions.
    """

    def __init__(
        self,
        connection: WebSocketResponse,
        kernel_params: dict[str, Any],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Create a new kernel for this particular websocket connection."""
        self.connection = connection
        self.kernel_params = kernel_params
        # Make sure global and local vars are not set to None.
        if kernel_params.get('global_vars', None) is None:
            self.kernel_params['global_vars'] = {}
        if kernel_params.get('local_vars', None) is None:
            self.kernel_params['local_vars'] = {}
        self.loop = loop

        # Installed log handlers, keyed by log source name.
        self.logger_handlers: dict[str, WebSocketStreamingResponder] = {}
        self.connected = False
        python_completer = PythonCompleter(
            self.get_globals,
            self.get_locals,
            lambda: True,
        )
        all_completers: list[Completer] = [python_completer]

        sentence_completions = kernel_params.get('sentence_completions')
        if sentence_completions:
            all_completers.append(create_word_completer(sentence_completions))

        # Merge default Python completer with the new custom one.
        self.completer = merge_completers(all_completers)

    async def handle_request(self, request) -> str:
        """Handle the request from web browser.

        Args:
          request: JSON text of an object with 'type', 'data' and 'id' keys.

        Returns:
          The JSON serialized response, 'unknown' for an unrecognized request
          type, or an empty string if a ValueError (which includes JSON
          decoding errors) was raised while handling the request.

        Raises:
          UnknownRequestType: The request is not a JSON object or has no type.
          UnknownRequestData: The request has no data.
          MissingCallId: The request has no id.
          KeyError: The request data lacks an attribute needed by the type.
        """
        try:
            parsed_request = json.loads(request)
            if not isinstance(parsed_request, dict):
                raise UnknownRequestType(
                    f'Unknown request type: {parsed_request}'
                )

            request_type = parsed_request.get('type', None)
            request_data = parsed_request.get('data', None)
            call_id = parsed_request.get('id', None)
            if request_type is None:
                raise UnknownRequestType(
                    f'Unknown request type: {parsed_request}'
                )
            if request_data is None:
                raise UnknownRequestData(
                    f'Unknown request data: {parsed_request}'
                )
            if call_id is None:
                raise MissingCallId(f'Missing call id: {parsed_request}')

            # In each branch only the attribute lookups are guarded, so a
            # KeyError raised inside a handler is not reported as a missing
            # request attribute.
            if request_type == 'autocomplete':
                try:
                    code = request_data['code']
                    cursor_pos = request_data['cursor_pos']
                except KeyError as error:
                    raise KeyError(
                        'Missing data.code or data.cursor_pos attributes:'
                        f'{request_data}'
                    ) from error
                completions = self.handle_autocompletion(code, cursor_pos)
                return json.dumps({'id': call_id, 'data': completions})

            if request_type == 'eval':
                try:
                    code = request_data['code']
                except KeyError as error:
                    raise KeyError(
                        f'Missing data.code attributes: {request_data}'
                    ) from error
                result = await self.handle_eval(code)
                return json.dumps({'id': call_id, 'data': result})

            if request_type == 'log_source_list':
                sources = self.handle_log_source_list()
                return json.dumps({'id': call_id, 'data': sources})

            if request_type == 'log_source_subscribe':
                try:
                    logger_name = request_data['name']
                except KeyError as error:
                    raise KeyError(
                        f'Missing data.name attributes: {request_data}'
                    ) from error
                # Close requests have same call id with just .close = true
                if 'close' in parsed_request:
                    succeeded = self.handle_log_source_unsubscribe(
                        logger_name, call_id
                    )
                else:
                    succeeded = self.handle_log_source_subscribe(
                        logger_name, call_id
                    )
                return json.dumps(
                    {
                        'id': call_id,
                        'streaming': True,
                        'data': succeeded,
                    }
                )

            return 'unknown'
        except ValueError:
            _LOG.error('Failed to parse request: %s', request)
            return ''

    async def handle_eval(self, code: str) -> dict[str, str] | None:
        """Evaluate user code and return output.

        sys.stdout and sys.stderr are replaced by in-memory buffers while the
        code runs and are always restored afterwards.

        Returns:
          A dict with 'result', 'stdout' and 'stderr' strings, or None if the
          code raised SystemExit.

        Raises:
          KeyboardInterrupt: Re-raised if raised by the evaluated code.
        """
        # Patch stdout and stderr to capture repl print() statements.
        temp_stdout = io.StringIO()
        temp_stderr = io.StringIO()
        original_stdout = sys.stdout
        original_stderr = sys.stderr

        sys.stdout = temp_stdout
        sys.stderr = temp_stderr
        try:
            result = await self._eval_async(code)
        except KeyboardInterrupt:
            raise
        except SystemExit:
            return None
        except BaseException as e:  # pylint: disable=broad-exception-caught
            # NOTE(review): the error text is formatted here and again when
            # the response is built below, so it is sent as the repr of a
            # string. Kept as is in case the client relies on it.
            result = _format_result_output(e)
        finally:
            # Always restore original stdout and stderr, on every exit path.
            sys.stdout = original_stdout
            sys.stderr = original_stderr

        return {
            'result': (
                _format_result_output(result) if result is not None else ''
            ),
            'stdout': temp_stdout.getvalue(),
            'stderr': temp_stderr.getvalue(),
        }

    async def _eval_async(self, code: str) -> Any:
        """Evaluate the code and return result.

        The code is compiled as an expression first. If that is a syntax
        error it is compiled and run as statements instead.

        NOTE: this runs arbitrary code received from the websocket client by
        design; it must only be reachable by trusted clients.

        Returns:
          The value of the expression, or None when the code was run as
          statements. Awaitable results are awaited if the compiled code has
          the coroutine flag set.
        """

        # WORKAROUND: Due to a bug in Jedi, the current directory is removed
        # from sys.path. See: https://github.com/davidhalter/jedi/issues/1148
        if '' not in sys.path:
            sys.path.insert(0, '')

        # Try eval first
        try:
            compiled_code = compile_code(code, 'eval')
        except SyntaxError:
            # If not a valid `eval` expression, compile as `exec` expression
            compiled_code = compile_code(code, 'exec')

        # Evaluate the code object that was just compiled, so the coroutine
        # flag checked below belongs to the code that actually ran.
        result = eval(  # pylint: disable=eval-used
            compiled_code, self.get_globals(), self.get_locals()
        )

        if _has_coroutine_flag(compiled_code):
            result = await result

        return result

    def handle_autocompletion(
        self, code: str, cursor_pos: int
    ) -> list[dict[str, str]]:
        """Return formatted completions for code at the given cursor position.

        See format_completions() for the shape of the returned dicts.
        """
        doc = Document(code, cursor_pos)
        all_completions = list(
            self.completer.get_completions(
                doc,
                CompleteEvent(completion_requested=False, text_inserted=True),
            )
        )
        return format_completions(all_completions)

    def handle_disconnect(self) -> None:
        """Mark the kernel as disconnected and remove its log handlers."""
        _LOG.info('pw_console.web_kernel disconnecting.')
        self.connected = False
        # Clean up all log handlers as we are shutting down. Only log sources
        # that were subscribed to have a handler installed.
        for logger_name, handler in self.logger_handlers.items():
            for logger in self._loggers_for(logger_name):
                logger.removeHandler(handler)
        self.logger_handlers.clear()

    def get_globals(self) -> dict[str, Any]:
        """Return the global variables used for evaluation and completion."""
        return self.kernel_params.get('global_vars', globals())

    def get_locals(self) -> dict[str, Any]:
        """Return the local variables used for evaluation and completion."""
        return self.kernel_params.get('local_vars', self.get_globals())

    def _loggers_for(self, logger_name) -> Any:
        """Return the loggers of a log source, or an empty list if unknown."""
        loggers_by_name = self.kernel_params.get('loggers') or {}
        return loggers_by_name.get(logger_name) or []

    def handle_log_source_list(self) -> list[str]:
        """Return the names of the log sources in the kernel parameters."""
        return list(self.kernel_params.get('loggers') or {})

    def handle_log_source_subscribe(self, logger_name, request_id) -> bool:
        """Start streaming logs of a log source to the given request id.

        The streaming handler is created and added to the source's loggers on
        the first subscription; later subscriptions reuse it.

        Returns:
          True if subscribed, False if the log source is unknown or empty.
        """
        loggers = self._loggers_for(logger_name)
        if not loggers:
            return False

        handler = self.logger_handlers.get(logger_name)
        if handler is None:
            handler = WebSocketStreamingResponder(
                self.connection,
                self.loop,
            )
            self.logger_handlers[logger_name] = handler
            for logger in loggers:
                logger.addHandler(handler)

        handler.request_ids.append(request_id)
        return True

    def handle_log_source_unsubscribe(self, logger_name, request_id) -> bool:
        """Stop streaming logs of a log source to the given request id.

        Returns:
          True if unsubscribed, False if the log source is unknown or the
          request id was not subscribed to it.
        """
        loggers = self._loggers_for(logger_name)
        handler = self.logger_handlers.get(logger_name)
        if not loggers or handler is None:
            return False

        try:
            handler.request_ids.remove(request_id)
        except ValueError:
            # This request id never subscribed to the log source.
            return False

        # Remove handler if all requests have unsubscribed
        if not handler.request_ids:
            for logger in loggers:
                logger.removeHandler(handler)
            del self.logger_handlers[logger_name]
        return True