xref: /aosp_15_r20/external/pigweed/pw_stream/py/stream_readers_test.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1#!/usr/bin/env python3
2# Copyright 2023 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
"""Unit tests for pw_stream.stream_readers."""
16
17from contextlib import contextmanager
18import logging
19import queue
20import threading
21import time
22import unittest
23
24from pw_stream import stream_readers
25
26
class QueueFile:
    """A fake file object backed by a queue for testing.

    The "operator" (the test) feeds it through put_read_data(),
    cause_read_exc() and set_read_eof(). The "consumer" (the code under test)
    pulls from it with read(), which blocks until the operator has queued
    something.
    """

    # Marker queued by set_read_eof(); read() compares it by identity.
    EOF = object()

    # How often wait_for_drain() re-checks the queue, in seconds.
    _DRAIN_POLL_INTERVAL = 0.1

    def __init__(self):
        # Operator puts; consumer gets
        self._q = queue.Queue()

        # Consumer side access only!
        self._readbuf = b''
        self._eof = False

    ###############
    # Consumer side

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.close()

    def _read_from_buf(self, size: int) -> bytes:
        """Removes and returns up to ``size`` bytes from the read buffer.

        A negative ``size`` takes everything currently buffered. Slicing with
        a negative bound would otherwise leave bytes behind.
        """
        if size < 0:
            size = len(self._readbuf)
        data = self._readbuf[:size]
        self._readbuf = self._readbuf[size:]
        return data

    def read(self, size: int = 1) -> bytes:
        """Reads up to ``size`` bytes, possibly fewer (a short read).

        Buffered bytes are returned without touching the queue. Otherwise
        this blocks until exactly one item is taken off the queue.

        Args:
          size: Maximum number of bytes to return. 0 returns b'' immediately
            without blocking. A negative value returns the whole buffered (or
            next queued) chunk.

        Returns:
          The bytes read. Returns b'' once EOF has been reported, or when an
          empty chunk was queued.

        Raises:
          Exception: Whatever exception was queued with cause_read_exc()
            (close() queues one too).
        """
        if size == 0:
            return b''

        # First try to get buffered data; a short read is acceptable.
        data = self._read_from_buf(size)
        if data:
            return data

        # No more data in the buffer
        assert not self._readbuf

        if self._eof:
            return b''

        # Nothing buffered; block on the queue.
        item = self._q.get()

        # NOTE: We can't call Queue.task_done() here because the reader hasn't
        # actually *acted* on the read item yet.

        # Queued data
        if isinstance(item, bytes):
            self._readbuf = item
            return self._read_from_buf(size)

        # Queued exception
        if isinstance(item, Exception):
            raise item

        # Report EOF; it is sticky for all later reads.
        if item is self.EOF:
            self._eof = True
            return b''

        raise Exception('unexpected item type')

    def write(self, data: bytes) -> None:
        """Discards ``data``; writes are not modeled by this fake."""

    #####################
    # Weird middle ground

    # It is a violation of most file-like object APIs for one thread to call
    # close() while another thread is calling read(). The behavior is
    # undefined.
    #
    # - On Linux, close() may wake up a select(), leaving the caller with a bad
    #   file descriptor (which could get reused!)
    # - Or the read() could continue to block indefinitely.
    #
    # We choose to cause a subsequent/parallel read to receive an exception.
    def close(self) -> None:
        """Makes the next read() that reaches the queue raise 'closed'."""
        self.cause_read_exc(Exception('closed'))

    ###############
    # Operator side

    def put_read_data(self, data: bytes) -> None:
        """Queues ``data`` to be returned by future read() calls."""
        self._q.put(data)

    def cause_read_exc(self, exc: Exception) -> None:
        """Queues ``exc`` to be raised by a future read() call."""
        self._q.put(exc)

    def set_read_eof(self) -> None:
        """Queues an EOF marker; read() returns b'' from then on."""
        self._q.put(self.EOF)

    def wait_for_drain(self, timeout=None) -> None:
        """Wait for the queue to drain (be fully consumed).

        "Drained" means every queued item has been taken by read(). It does
        not mean the consumer has finished acting on those items.

        Args:
          timeout: The maximum time (in seconds) to wait, or wait forever
            if None.

        Raises:
          TimeoutError: If timeout is given and has elapsed.
        """
        # It would be great to use Queue.join() here, but that requires the
        # consumer to call Queue.task_done(), and we can't do that because
        # the consumer of read() doesn't know anything about it.
        # Instead, we poll.  ¯\_(ツ)_/¯
        # monotonic() is immune to wall-clock adjustments during the wait.
        start_time = time.monotonic()
        while not self._q.empty():
            poll_interval = self._DRAIN_POLL_INTERVAL
            if timeout is not None:
                elapsed = time.monotonic() - start_time
                if elapsed > timeout:
                    raise TimeoutError(f"Queue not empty after {elapsed} sec")
                # Don't sleep past the deadline.
                poll_interval = min(poll_interval, timeout - elapsed)
            time.sleep(poll_interval)
144
145
class QueueFileTest(unittest.TestCase):
    """Exercises the QueueFile test fake on its own."""

    def test_read_data(self) -> None:
        qfile = QueueFile()
        qfile.put_read_data(b'hello')
        got = qfile.read(5)
        self.assertEqual(got, b'hello')

    def test_read_data_multi_read(self) -> None:
        # A single queued chunk can satisfy several smaller reads.
        qfile = QueueFile()
        qfile.put_read_data(b'helloworld')
        for expected in (b'hello', b'world'):
            self.assertEqual(qfile.read(5), expected)

    def test_read_data_multi_put(self) -> None:
        # Queued chunks come back in FIFO order.
        qfile = QueueFile()
        for chunk in (b'hello', b'world'):
            qfile.put_read_data(chunk)
        for expected in (b'hello', b'world'):
            self.assertEqual(qfile.read(5), expected)

    def test_read_eof(self) -> None:
        qfile = QueueFile()
        qfile.set_read_eof()
        self.assertEqual(qfile.read(5), b'')

    def test_read_exception(self) -> None:
        msg = 'test exception'
        qfile = QueueFile()
        qfile.cause_read_exc(ValueError(msg))
        with self.assertRaisesRegex(ValueError, msg):
            qfile.read(5)

    def test_wait_for_drain_works(self) -> None:
        qfile = QueueFile()
        qfile.put_read_data(b'hello')
        # Taking the only item off the queue drains it, even though most of
        # the chunk is still sitting in the read buffer.
        qfile.read()
        try:
            # Any timeout works; the queue is already empty.
            qfile.wait_for_drain(0.1)
        except TimeoutError:
            self.fail("wait_for_drain raised TimeoutError")

    def test_wait_for_drain_raises(self) -> None:
        qfile = QueueFile()
        qfile.put_read_data(b'hello')
        # Nothing consumes the queue, so any timeout must expire.
        with self.assertRaises(TimeoutError):
            qfile.wait_for_drain(0.1)
197
198
class Sentinel:
    """Unique marker object whose repr() is the fixed text 'Sentinel'."""

    def __repr__(self) -> str:
        text = 'Sentinel'
        return text
202
203
class _QueueReader(stream_readers.CancellableReader):
    """CancellableReader whose cancellation closes the wrapped object.

    The tests in this file wrap a QueueFile. Closing a QueueFile queues an
    exception, so a read() blocked on it raises instead of hanging.
    """

    def cancel_read(self) -> None:
        wrapped = self._base_obj
        wrapped.close()
207
208
def on_read_error(exc: Exception) -> None:
    """Logs ``exc``, with its traceback, at ERROR level.

    Uses the 'pw_stream.stream_readers' logger so the tests can capture it
    with assertLogs().
    """
    logging.getLogger('pw_stream.stream_readers').error(
        'data reader encountered an error', exc_info=exc
    )
212
213
214def _null_data_processor(data):
215    del data
216
217
218def _null_frame_handler(frame):
219    del frame
220
221
class _ScopedReaderAndExecutor(stream_readers.DataReaderAndExecutor):
    """DataReaderAndExecutor usable as a context manager.

    Entering the ``with`` block calls start(). Leaving it calls stop(), also
    when the body raises. Exceptions from the body are not suppressed.
    """

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc_info) -> None:
        # Returning None (falsy) lets any exception from the body propagate.
        self.stop()
231
232
# Timeout, in seconds, for QueueFile.wait_for_drain() in the tests below.
# Draining should take <10ms but we'll wait up to 1000x longer.
_QUEUE_DRAIN_TIMEOUT: float = 10.0
235
236
class DataReaderAndExecutorTest(unittest.TestCase):
    """Tests the DataReaderAndExecutor class."""

    # NOTE: There is no test here for stream EOF because Serial.read()
    # can return an empty result if configured with timeout != None.
    # The reader thread will continue in this case.

    def test_clean_close_after_stream_close(self) -> None:
        """Assert the reader shuts down cleanly, without error logs, on stop."""
        # See b/293595266.
        file = QueueFile()

        with self.assert_no_stream_stream_readers_error_logs():
            with file:
                with _ScopedReaderAndExecutor(
                    reader=_QueueReader(file),
                    on_read_error=on_read_error,
                    data_processor=_null_data_processor,
                    frame_handler=_null_frame_handler,
                ):
                    # Feed an empty chunk and wait until it has been taken.
                    # After an empty read the reader keeps going (see NOTE
                    # above), so it ends up blocked in read() rather than
                    # exiting immediately.
                    file.put_read_data(b'')
                    file.wait_for_drain(_QUEUE_DRAIN_TIMEOUT)

                # _ScopedReaderAndExecutor.__exit__ calls stop() while the
                # reader thread is blocked on file.read(). stop() is expected
                # to cancel the read via _QueueReader.cancel_read(), which
                # closes the QueueFile.

            # Closing the QueueFile makes the blocked read() raise (by
            # implementation choice). The reader should handle it by *not*
            # logging it and exiting immediately.

        self.assert_no_background_threads_running()

    def test_device_handles_read_exception(self) -> None:
        """Assert a read() exception is logged once and the reader exits."""
        # See b/293595266.
        file = QueueFile()

        logger = logging.getLogger('pw_stream.stream_readers')
        test_exc = Exception('boom')
        with self.assertLogs(logger, level=logging.ERROR) as ctx:
            with _ScopedReaderAndExecutor(
                reader=_QueueReader(file),
                on_read_error=on_read_error,
                data_processor=_null_data_processor,
                frame_handler=_null_frame_handler,
            ):
                # Cause read() to raise an exception. The reader should
                # handle it by logging it and exiting immediately.
                file.cause_read_exc(test_exc)
                file.wait_for_drain(_QUEUE_DRAIN_TIMEOUT)

        # Assert exactly one error was logged, carrying our exception.
        self.assertEqual(len(ctx.records), 1)
        rec = ctx.records[0]
        self.assertIsNotNone(rec.exc_info)
        assert rec.exc_info is not None  # for mypy
        self.assertIs(rec.exc_info[1], test_exc)

        self.assert_no_background_threads_running()

    @contextmanager
    def assert_no_stream_stream_readers_error_logs(self):
        """Asserts the body logs no ERROR to 'pw_stream.stream_readers'.

        Yields:
          The assertLogs() capture context. It also holds the sentinel record
          logged here.
        """
        logger = logging.getLogger('pw_stream.stream_readers')
        sentinel = Sentinel()
        with self.assertLogs(logger, level=logging.ERROR) as ctx:
            # TODO: b/294861320 - use assertNoLogs() in Python 3.10+
            # We actually want to assert there are no errors, but
            # TestCase.assertNoLogs() is not available until Python 3.10.
            # So we log one error to keep the test from failing and manually
            # inspect the list of captured records.
            logger.error(sentinel)

            yield ctx

        # Only the sentinel may have been captured.
        self.assertEqual([record.msg for record in ctx.records], [sentinel])

    def assert_no_background_threads_running(self):
        """Asserts the calling thread is the only live thread."""
        self.assertEqual(threading.enumerate(), [threading.current_thread()])
317
318
if __name__ == '__main__':
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
321