xref: /aosp_15_r20/external/pigweed/pw_hdlc/py/pw_hdlc/rpc.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Utilities for using HDLC with ``pw_rpc``."""
15
16from __future__ import annotations
17
18import io
19import logging
20import queue
21import sys
22import threading
23import time
24from typing import (
25    Any,
26    BinaryIO,
27    Callable,
28    Iterable,
29    Sequence,
30    TypeVar,
31)
32
33import pw_rpc
34from pw_rpc import client_utils
35from pw_hdlc.decode import Frame, FrameDecoder
36from pw_hdlc import encode
37from pw_stream import stream_readers
38
# Logger for this module, plus a custom level one step below DEBUG that is
# used for very chatty output such as every encoded frame that gets written.
_LOG: logging.Logger = logging.getLogger('pw_hdlc.rpc')
_VERBOSE: int = logging.DEBUG - 1
41
42
# Backwards-compatibility aliases. These objects were formerly defined in this
# module and now live in ``pw_stream.stream_readers`` and
# ``pw_rpc.client_utils``. They are re-exported here so existing imports from
# this module keep working.
CancellableReader = stream_readers.CancellableReader
SelectableReader = stream_readers.SelectableReader
SocketReader = stream_readers.SocketReader
SerialReader = stream_readers.SerialReader
DataReaderAndExecutor = stream_readers.DataReaderAndExecutor
SocketSubprocess = stream_readers.SocketSubprocess
PathsModulesOrProtoLibrary = client_utils.PathsModulesOrProtoLibrary
RpcClient = client_utils.RpcClient
NoEncodingSingleChannelRpcClient = client_utils.NoEncodingSingleChannelRpcClient

# Public type variable. It is not referenced elsewhere in this module;
# presumably it is kept for external importers.
FrameTypeT = TypeVar('FrameTypeT')
55
56
57# Default values for channel using HDLC encoding.
58DEFAULT_CHANNEL_ID = 1
59DEFAULT_ADDRESS = ord('R')
60STDOUT_ADDRESS = 1
61
# Maps an HDLC frame address to the callback that receives frames sent to it.
FrameHandlers = dict[
    int,
    Callable[[Frame], Any],
]
63
64
65# Default channel output for using HDLC encoding.
def channel_output(
    writer: Callable[[bytes], Any],
    address: int = DEFAULT_ADDRESS,
    delay_s: float = 0,
) -> Callable[[bytes], None]:
    """Returns a function that can be used as a channel output for ``pw_rpc``.

    Each packet passed to the returned function is wrapped in an HDLC UI frame
    carrying ``address`` and handed to ``writer``.

    Args:
      writer: Function that sends raw bytes onward (e.g. to a device).
      address: HDLC address placed in every outgoing frame.
      delay_s: If nonzero, the encoded frame is written one byte at a time,
        sleeping ``delay_s`` seconds before each byte. This slows down writes
        in case unbuffered serial is in use.

    Returns:
      A callable that accepts one RPC packet as ``bytes``.
    """

    def encode_frame(data: bytes) -> bytes:
        # Shared by both write paths so frames are logged the same way
        # regardless of whether writes are delayed.
        frame = encode.ui_frame(address, data)
        _LOG.log(_VERBOSE, 'Write %2d B: %s', len(frame), frame)
        return frame

    if delay_s:

        def slow_write(data: bytes) -> None:
            """Slows down writes in case unbuffered serial is in use."""
            for byte in encode_frame(data):
                time.sleep(delay_s)
                writer(bytes([byte]))

        return slow_write

    def write_hdlc(data: bytes) -> None:
        writer(encode_frame(data))

    return write_hdlc
91
92
def default_channels(write: Callable[[bytes], Any]) -> list[pw_rpc.Channel]:
    """Builds the standard single-channel list for HDLC-encoded RPC.

    Args:
      write: Function that sends raw bytes onward.

    Returns:
      A one-element list holding a channel with ID ``DEFAULT_CHANNEL_ID``
      whose output HDLC-encodes packets via ``channel_output(write)``.
    """
    hdlc_output = channel_output(write)
    only_channel = pw_rpc.Channel(DEFAULT_CHANNEL_ID, hdlc_output)
    return [only_channel]
96
97
def _stdout_buffer() -> BinaryIO:
    """Returns a binary stream for the current standard output.

    ``sys.stdout.buffer`` is not guaranteed to exist (see
    https://docs.python.org/3/library/io.html#io.TextIOBase.buffer), e.g. when
    ``sys.stdout`` is wrapped with something that does not offer it. In that
    case this falls back to ``sys.__stdout__.buffer``.

    Raises:
      RuntimeError: if neither stream provides a binary buffer (for example
        when ``sys.__stdout__`` is ``None``).
    """
    for stream in (sys.stdout, sys.__stdout__):
        # getattr on None (no stream) also yields the None default.
        buffer = getattr(stream, 'buffer', None)
        if buffer is not None:
            return buffer
    raise RuntimeError(
        'No binary stdout stream is available; pass an explicit output.'
    )


def write_to_file(
    data: bytes,
    output: BinaryIO | None = None,
) -> None:
    """Writes ``data`` followed by a newline to ``output`` and flushes it.

    Args:
      data: Bytes to write.
      output: Binary stream to write to. Defaults to the binary buffer of
        standard output. The default is resolved at call time rather than at
        import time, so a redirected ``sys.stdout`` is honored, and importing
        this module no longer fails when ``sys.__stdout__`` is ``None``.
    """
    if output is None:
        output = _stdout_buffer()
    output.write(data + b'\n')
    output.flush()
108
109
class HdlcRpcClient(client_utils.RpcClient):
    """An RPC client configured to run over HDLC.

    Expects HDLC frames to have addresses that dictate how to parse the HDLC
    payloads.
    """

    def __init__(
        self,
        reader: stream_readers.CancellableReader,
        paths_or_modules: client_utils.PathsModulesOrProtoLibrary,
        channels: Iterable[pw_rpc.Channel],
        output: Callable[[bytes], Any] = write_to_file,
        client_impl: pw_rpc.client.ClientImpl | None = None,
        *,
        _incoming_packet_filter_for_testing: (
            pw_rpc.ChannelManipulator | None
        ) = None,
        rpc_frames_address: int = DEFAULT_ADDRESS,
        log_frames_address: int = STDOUT_ADDRESS,
        extra_frame_handlers: FrameHandlers | None = None,
    ):
        """Creates an RPC client configured to communicate using HDLC.

        Args:
          reader: Readable object used to receive RPC packets.
          paths_or_modules: paths to .proto files or proto modules.
          channels: RPC channels to use for output.
          output: where to write ``stdout`` output from the device.
          client_impl: The RPC Client implementation. Defaults to the callback
            client implementation if not provided.
          _incoming_packet_filter_for_testing: Test hook. If set, incoming RPC
            packets are routed through this manipulator, which then forwards
            them to ``handle_rpc_packet``.
          rpc_frames_address: the address used in the HDLC frames for RPC
            packets. This can be the channel ID, or any custom address.
          log_frames_address: the address used in the HDLC frames for ``stdout``
            output from the device.
          extra_frame_handlers: Optional mapping of HDLC frame addresses to
            their callbacks. Entries are applied after the RPC and log
            handlers, so they override them when the addresses collide.
        """
        # Set up frame handling.
        rpc_output: Callable[[bytes], Any] = self.handle_rpc_packet
        if _incoming_packet_filter_for_testing is not None:
            _incoming_packet_filter_for_testing.send_packet = rpc_output
            rpc_output = _incoming_packet_filter_for_testing

        frame_handlers: FrameHandlers = {
            rpc_frames_address: lambda frame: rpc_output(frame.data),
            log_frames_address: lambda frame: output(frame.data),
        }
        if extra_frame_handlers:
            frame_handlers.update(extra_frame_handlers)

        def handle_frame(frame: Frame) -> None:
            # Suppress raising any frame errors to avoid crashes on data
            # processing, which may hide or drop other data.
            try:
                if not frame.ok():
                    _LOG.error('Failed to parse frame: %s', frame.status.value)
                    _LOG.debug('%s', frame.data)
                    return

                # Look the handler up separately from calling it, so a
                # KeyError raised *inside* a handler is reported as a handler
                # exception rather than as an unhandled address.
                try:
                    handler = frame_handlers[frame.address]
                except KeyError:
                    _LOG.warning(
                        'Unhandled frame for address %d: %s',
                        frame.address,
                        frame,
                    )
                    return

                handler(frame)
            except:  # pylint: disable=bare-except
                _LOG.exception('Exception in HDLC frame handler thread')

        decoder = FrameDecoder()

        def on_read_error(exc: Exception) -> None:
            _LOG.error('data reader encountered an error', exc_info=exc)

        reader_and_executor = stream_readers.DataReaderAndExecutor(
            reader, on_read_error, decoder.process_valid_frames, handle_frame
        )
        super().__init__(
            reader_and_executor, paths_or_modules, channels, client_impl
        )
192
193
class HdlcRpcLocalServerAndClient:
    """Runs an RPC server in a subprocess and connects to it over a socket.

    This can be used to run a local RPC server in an integration test.
    """

    def __init__(
        self,
        server_command: Sequence,
        port: int,
        protos: client_utils.PathsModulesOrProtoLibrary,
        *,
        incoming_processor: pw_rpc.ChannelManipulator | None = None,
        outgoing_processor: pw_rpc.ChannelManipulator | None = None,
    ) -> None:
        """Creates a new ``HdlcRpcLocalServerAndClient``.

        Args:
          server_command: Server command, passed to ``SocketSubprocess``
            together with ``port``.
          port: Port passed to ``SocketSubprocess``.
          protos: paths to .proto files or proto modules.
          incoming_processor: Optional manipulator that decoded incoming RPC
            packets pass through before reaching the client.
          outgoing_processor: Optional manipulator that outgoing HDLC-encoded
            data passes through before being sent on the socket.
        """

        self.server = stream_readers.SocketSubprocess(server_command, port)

        self._bytes_queue: queue.SimpleQueue[bytes] = queue.SimpleQueue()
        self._read_thread = threading.Thread(target=self._read_from_socket)
        self._read_thread.start()

        self.output = io.BytesIO()

        self.channel_output: Any = self.server.socket.sendall

        self._incoming_processor = incoming_processor
        if outgoing_processor is not None:
            outgoing_processor.send_packet = self.channel_output
            self.channel_output = outgoing_processor

        class QueueReader(stream_readers.CancellableReader):
            """Reads chunks that the socket thread placed on a queue."""

            def read(self) -> bytes:
                # Return b'' on a 3 s timeout instead of blocking forever.
                try:
                    return self._base_obj.get(timeout=3)
                except queue.Empty:
                    return b''

            def cancel_read(self) -> None:
                pass

        self._rpc_client = HdlcRpcClient(
            QueueReader(self._bytes_queue),
            protos,
            default_channels(self.channel_output),
            self.output.write,
            _incoming_packet_filter_for_testing=incoming_processor,
        )
        self.client = self._rpc_client.client

    def _read_from_socket(self):
        """Forwards raw socket data to the queue until EOF or a socket error.

        An empty ``bytes`` object is always enqueued last to mark the end of
        the stream.
        """
        while True:
            try:
                data = self.server.socket.recv(4096)
            except OSError:
                # The socket failed or was closed underneath us (presumably
                # by close()); treat it as end of stream rather than letting
                # the thread die without enqueueing EOF.
                _LOG.debug('Socket read ended with an error', exc_info=True)
                data = b''
            self._bytes_queue.put(data)
            if not data:
                return

    def close(self):
        """Shuts down the server, the RPC client, and the reader thread."""
        self.server.close()
        self._rpc_client.close()
        self._read_thread.join()
        # Close the output buffer last. Frames still being handled while the
        # client shuts down write to it, and writing to a closed BytesIO
        # raises ValueError.
        self.output.close()

    def __enter__(self) -> HdlcRpcLocalServerAndClient:
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()
263