# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Device classes to interact with targets via RPC."""


import logging
import os
from pathlib import Path
import tempfile
from types import ModuleType
from collections.abc import Iterable
from typing import Any, Callable, Sequence

from pw_file import file_pb2
from pw_hdlc import rpc
from pw_hdlc.decode import Frame
from pw_log import log_decoder
from pw_log_rpc import rpc_log_stream
from pw_metric import metric_parser
import pw_rpc
from pw_rpc import callback_client, console_tools, client_utils
import pw_transfer
from pw_transfer import transfer_pb2
from pw_stream import stream_readers
from pw_system import snapshot
from pw_thread import thread_analyzer
from pw_thread_protos import thread_pb2
from pw_tokenizer import detokenize
from pw_tokenizer.proto import decode_optionally_tokenized
from pw_unit_test.rpc import run_tests as pw_unit_test_run_tests, TestRecord


# Internal log for troubleshooting this tool (the console).
_LOG = logging.getLogger(__package__)

DEFAULT_DEVICE_LOGGER = logging.getLogger('rpc_device')


class Device:
    """Represents an RPC Client for a device running a Pigweed target.

    The target must have RPC support for the following services:
     - logging (only when ``use_rpc_logging`` is True)
     - file
     - transfer

    Individual methods additionally call ``pw.rpc.EchoService`` (echo),
    ``pw.system.proto.DeviceService`` (reboot/crash) and
    ``pw.thread.proto.ThreadSnapshotService`` (snapshot_peak_stack_usage).

    Note: use this class as a base for specialized device representations.
    """

    def __init__(
        # pylint: disable=too-many-arguments
        self,
        channel_id: int,
        reader: stream_readers.CancellableReader,
        write: Callable[[bytes], Any],
        proto_library: Iterable[ModuleType | Path],
        detokenizer: detokenize.Detokenizer | None = None,
        timestamp_decoder: Callable[[int], str] | None = None,
        rpc_timeout_s: float = 5,
        use_rpc_logging: bool = True,
        use_hdlc_encoding: bool = True,
        logger: logging.Logger | logging.LoggerAdapter = DEFAULT_DEVICE_LOGGER,
        extra_frame_handlers: dict[int, Callable[[bytes, Any], Any]]
        | None = None,
    ):
        """Creates the RPC client and starts log streaming if requested.

        Args:
          channel_id: The RPC channel ID used to talk to the device.
          reader: Reader the RPC client pulls incoming device bytes from.
          write: Callable used to send outgoing bytes to the device.
          proto_library: Compiled proto modules or proto paths for the
            client. ``transfer_pb2`` is added automatically if missing.
          detokenizer: Optional detokenizer for device output, RPC logs,
            metrics and crash snapshots.
          timestamp_decoder: Formats RPC log timestamps. Defaults to
            ``log_decoder.timestamp_parser_ns_since_boot``.
          rpc_timeout_s: Default unary RPC timeout; also used as the
            transfer response timeouts.
          use_rpc_logging: Start the RPC log stream immediately.
          use_hdlc_encoding: Use an HDLC-framed client. When False, a
            single-channel client without encoding is used instead.
          logger: Receives device logs. Its level is set to DEBUG here.
          extra_frame_handlers: Map of HDLC address to
            ``handler(frame_data, device)``. Only used with HDLC encoding.
        """
        self.channel_id = channel_id
        self.protos = list(proto_library)
        self.detokenizer = detokenizer
        self.rpc_timeout_s = rpc_timeout_s

        self.logger = logger
        self.logger.setLevel(logging.DEBUG)  # Allow all device logs through.

        callback_client_impl = callback_client.Impl(
            default_unary_timeout_s=self.rpc_timeout_s,
            default_stream_timeout_s=None,
        )

        def detokenize_and_log_output(data: bytes, _detokenizer=None):
            # Output callback handed to HdlcRpcClient (presumably frames that
            # are not RPC packets). Detokenize when possible; otherwise decode
            # as UTF-8, keeping undecodable bytes via surrogateescape.
            if self.detokenizer:
                log_messages = decode_optionally_tokenized(
                    self.detokenizer, data
                )
            else:
                log_messages = data.decode(
                    encoding='utf-8', errors='surrogateescape'
                )

            for line in log_messages.splitlines():
                self.logger.info(line)

        # Device has a hard dependency on transfer_pb2, so ensure it's
        # always been added to the list of compiled protos, rather than
        # requiring all clients to include it.
        if transfer_pb2 not in self.protos:
            self.protos.append(transfer_pb2)

        self.client: client_utils.RpcClient
        if use_hdlc_encoding:
            channels = [
                pw_rpc.Channel(self.channel_id, rpc.channel_output(write))
            ]

            # A factory (rather than a lambda in the comprehension below) so
            # each wrapper binds its own handler instead of the last one.
            def create_frame_handler_wrapper(
                handler: Callable[[bytes, Any], Any]
            ) -> Callable[[Frame], Any]:
                def handler_wrapper(frame: Frame):
                    handler(frame.data, self)

                return handler_wrapper

            extra_frame_handlers_wrapper: rpc.FrameHandlers = {
                address: create_frame_handler_wrapper(handler)
                for address, handler in (extra_frame_handlers or {}).items()
            }

            self.client = rpc.HdlcRpcClient(
                reader,
                self.protos,
                channels,
                detokenize_and_log_output,
                client_impl=callback_client_impl,
                extra_frame_handlers=extra_frame_handlers_wrapper,
            )
        else:
            channel = pw_rpc.Channel(self.channel_id, write)
            self.client = client_utils.NoEncodingSingleChannelRpcClient(
                reader,
                self.protos,
                channel,
                client_impl=callback_client_impl,
            )

        if use_rpc_logging:
            # Create the log decoder used by the LogStreamHandler.

            def decoded_log_handler(log: log_decoder.Log) -> None:
                log_decoder.log_decoded_log(log, self.logger)

            self._log_decoder = log_decoder.LogStreamDecoder(
                decoded_log_handler=decoded_log_handler,
                detokenizer=self.detokenizer,
                source_name='RpcDevice',
                timestamp_parser=(
                    timestamp_decoder
                    if timestamp_decoder
                    else log_decoder.timestamp_parser_ns_since_boot
                ),
            )

            # Start listening to logs as soon as possible.
            self.log_stream_handler = rpc_log_stream.LogStreamHandler(
                self.rpcs, self._log_decoder
            )
            self.log_stream_handler.start_logging()

        # Create the transfer manager
        self.transfer_service = self.rpcs.pw.transfer.Transfer
        self.transfer_manager = pw_transfer.Manager(
            self.transfer_service,
            default_response_timeout_s=self.rpc_timeout_s,
            initial_response_timeout_s=self.rpc_timeout_s,
            default_protocol_version=pw_transfer.ProtocolVersion.LATEST,
        )

    def __enter__(self):
        """Returns this device; the RPC client is closed on exit."""
        return self

    def __exit__(self, *exc_info):
        self.close()

    def close(self) -> None:
        """Closes the underlying RPC client."""
        self.client.close()

    def info(self) -> console_tools.ClientInfo:
        """Returns a ClientInfo named 'device' for this device's RPCs."""
        return console_tools.ClientInfo('device', self.rpcs, self.client.client)

    @property
    def rpcs(self) -> Any:
        """Returns an object for accessing services on the specified channel."""
        return next(iter(self.client.client.channels())).rpcs

    def run_tests(self, timeout_s: float | None = 5) -> TestRecord:
        """Runs the unit tests on this device.

        Args:
          timeout_s: Timeout forwarded to ``pw_unit_test.rpc.run_tests``.
        Returns:
          The TestRecord produced by the test run.
        """
        return pw_unit_test_run_tests(self.rpcs, timeout_s=timeout_s)

    def echo(self, msg: str) -> str:
        """Sends a string to the device and back, returning the result.

        ``unwrap_or_raise()`` raises if the Echo RPC does not succeed.
        """
        return self.rpcs.pw.rpc.EchoService.Echo(msg=msg).unwrap_or_raise().msg

    def reboot(self):
        """Triggers a reboot to run asynchronously on the device.

        This function *does not* wait for the reboot to complete."""
        # `invoke` rather than call in order to ignore the result. No result
        # will be sent when the device reboots.
        self.rpcs.pw.system.proto.DeviceService.Reboot.invoke()

    def crash(self):
        """Triggers a crash to run asynchronously on the device.

        This function *does not* wait for the crash to complete."""
        # `invoke` rather than call in order to ignore the result. No result
        # will be sent when the device crashes.
        self.rpcs.pw.system.proto.DeviceService.Crash.invoke()

    def get_and_log_metrics(self) -> dict:
        """Retrieves the parsed metrics and logs them to the console.

        Each leaf metric is logged as ``/<group>/.../<name>: <value>``.

        Returns:
          The nested metrics dict returned by ``metric_parser.parse_metrics``.
        """
        metrics = metric_parser.parse_metrics(
            self.rpcs, self.detokenizer, self.rpc_timeout_s
        )

        def print_metrics(metrics, path):
            """Traverses dictionaries, until a non-dict value is reached."""
            for path_name, metric in metrics.items():
                if isinstance(metric, dict):
                    print_metrics(metric, path + '/' + path_name)
                else:
                    _LOG.info('%s/%s: %s', path, path_name, str(metric))

        print_metrics(metrics, '')
        return metrics

    def snapshot_peak_stack_usage(self, thread_name: str | None = None) -> None:
        """Requests peak stack usage from the device and logs the analysis.

        Args:
          thread_name: Sent as the ``name`` field of the GetPeakStackUsage
            request. None presumably means "all threads" -- confirm against
            the ThreadSnapshotService implementation.

        The RPC status is not checked; only the streamed responses are used.
        """
        snapshot_service = self.rpcs.pw.thread.proto.ThreadSnapshotService
        _, rsp = snapshot_service.GetPeakStackUsage(name=thread_name)

        # Merge the threads from every streamed response into one message so
        # the analyzer reports on all of them together.
        thread_info = thread_pb2.SnapshotThreadInfo()
        for thread_info_block in rsp:
            for thread in thread_info_block.threads:
                thread_info.threads.append(thread)
        for line in str(
            thread_analyzer.ThreadSnapshotAnalyzer(thread_info)
        ).splitlines():
            _LOG.info('%s', line)

    def list_files(self) -> Sequence[file_pb2.ListResponse]:
        """Lists all files on this device.

        Returns:
          A sequence of responses from the List() RPC, or an empty list
          (after logging an error) if the RPC did not succeed.
        """
        fs_service = self.rpcs.pw.file.FileSystem
        stream_response = fs_service.List()
        if not stream_response.status.ok():
            _LOG.error('Failed to list files %s', stream_response.status)
            return []

        return stream_response.responses

    def delete_file(self, path: str) -> bool:
        """Delete a file on this device.

        Args:
          path: The path of the file to delete.
        Returns:
          True on successful deletion, False on failure.
        """
        fs_service = self.rpcs.pw.file.FileSystem
        req = file_pb2.DeleteRequest(path=path)
        stream_response = fs_service.Delete(req)
        if not stream_response.status.ok():
            _LOG.error(
                'Failed to delete file %s file: %s',
                path,
                stream_response.status,
            )
            return False

        return True

    def transfer_file(self, file_id: int, dest_path: str) -> bool:
        """Transfer a file on this device to the host.

        Args:
          file_id: The file_id of the file to transfer from device.
          dest_path: The destination path to save the file to on the host.
        Returns:
          True on successful transfer, False if the transfer failed. A
          failed transfer (pw_transfer.Error) is logged, not re-raised.
        Raises:
          OSError: The destination file could not be written.
        """
        try:
            data = self.transfer_manager.read(file_id)
        except pw_transfer.Error:
            _LOG.exception('Failed to transfer file_id %i', file_id)
            return False

        with open(dest_path, "wb") as bin_file:
            bin_file.write(data)
        _LOG.info('Successfully wrote file to %s', os.path.abspath(dest_path))
        return True

    @staticmethod
    def _text_snapshot_path(binary_path: str) -> str:
        """Returns the host path for the decoded text form of a snapshot.

        Only a trailing ``.snapshot`` extension is replaced by ``.txt``; any
        other name gets ``.txt`` appended, so the text file can never reuse
        (and overwrite) the path of the binary snapshot.
        """
        root, ext = os.path.splitext(binary_path)
        if ext == '.snapshot':
            return root + '.txt'
        return binary_path + '.txt'

    def get_crash_snapshots(self, crash_log_path: str | None = None) -> bool:
        r"""Transfer any crash snapshots on this device to the host.

        Each snapshot under ``/snapshots/crash_*`` is transferred, decoded
        to a text file next to it, and then deleted from the device.

        Args:
          crash_log_path: The host path to store the crash files.
            If not specified, defaults to `/tmp` or `C:\TEMP` on Windows.
        Returns:
          True on successful download of snapshot, or no snapshots
          on device. False on failure to download snapshot.
        """
        if crash_log_path is None:
            crash_log_path = tempfile.gettempdir()

        snapshot_paths: list[file_pb2.Path] = []
        for response in self.list_files():
            for snapshot_path in response.paths:
                if snapshot_path.path.startswith('/snapshots/crash_'):
                    snapshot_paths.append(snapshot_path)

        if not snapshot_paths:
            _LOG.info('No crash snapshot on the device.')
            return True

        for snapshot_path in snapshot_paths:
            dest_snapshot = os.path.join(
                crash_log_path, os.path.basename(snapshot_path.path)
            )
            if not self.transfer_file(snapshot_path.file_id, dest_snapshot):
                return False

            decoded_snapshot: str
            with open(dest_snapshot, 'rb') as f:
                decoded_snapshot = snapshot.decode_snapshot(
                    self.detokenizer, f.read()
                )

            dest_text_snapshot = self._text_snapshot_path(dest_snapshot)
            # Explicit UTF-8 so non-ASCII snapshot text can be written on
            # hosts whose locale encoding is not UTF-8.
            with open(dest_text_snapshot, 'w', encoding='utf-8') as f:
                f.write(decoded_snapshot)
            _LOG.info('Wrote crash snapshot to: %s', dest_text_snapshot)

            # Only remove the device copy after the host copy is written.
            if not self.delete_file(snapshot_path.path):
                return False

        return True