1#!/usr/bin/env python3 2# Copyright 2023 The Pigweed Authors 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); you may not 5# use this file except in compliance with the License. You may obtain a copy of 6# the License at 7# 8# https://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 13# License for the specific language governing permissions and limitations under 14# the License. 15"""device module unit tests""" 16 17from contextlib import contextmanager 18import logging 19import queue 20import threading 21import time 22import unittest 23 24from pw_stream import stream_readers 25 26 27class QueueFile: 28 """A fake file object backed by a queue for testing.""" 29 30 EOF = object() 31 32 def __init__(self): 33 # Operator puts; consumer gets 34 self._q = queue.Queue() 35 36 # Consumer side access only! 37 self._readbuf = b'' 38 self._eof = False 39 40 ############### 41 # Consumer side 42 43 def __enter__(self): 44 return self 45 46 def __exit__(self, *exc_info): 47 self.close() 48 49 def _read_from_buf(self, size: int) -> bytes: 50 data = self._readbuf[:size] 51 self._readbuf = self._readbuf[size:] 52 return data 53 54 def read(self, size: int = 1) -> bytes: 55 """Reads data from the queue""" 56 # First try to get buffered data 57 data = self._read_from_buf(size) 58 assert len(data) <= size 59 size -= len(data) 60 61 # if size == 0: 62 if data: 63 return data 64 65 # No more data in the buffer 66 assert not self._readbuf 67 68 if self._eof: 69 return data # may be empty 70 71 # Not enough in the buffer; block on the queue 72 item = self._q.get() 73 74 # NOTE: We can't call Queue.task_done() here because the reader hasn't 75 # actually *acted* on the read item yet. 76 77 # Queued data 78 if isinstance(item, bytes): 79 self._readbuf = item 80 return self._read_from_buf(size) 81 82 # Queued exception 83 if isinstance(item, Exception): 84 raise item 85 86 # Report EOF 87 if item is self.EOF: 88 self._eof = True 89 return data # may be empty 90 91 raise Exception('unexpected item type') 92 93 def write(self, data: bytes) -> None: 94 pass 95 96 ##################### 97 # Weird middle ground 98 99 # It is a violation of most file-like object APIs for one thread to call 100 # close() while another thread is calling read(). The behavior is 101 # undefined. 102 # 103 # - On Linux, close() may wake up a select(), leaving the caller with a bad 104 # file descriptor (which could get reused!) 105 # - Or the read() could continue to block indefinitely. 106 # 107 # We choose to cause a subsequent/parallel read to receive an exception. 108 def close(self) -> None: 109 self.cause_read_exc(Exception('closed')) 110 111 ############### 112 # Operator side 113 114 def put_read_data(self, data: bytes) -> None: 115 self._q.put(data) 116 117 def cause_read_exc(self, exc: Exception) -> None: 118 self._q.put(exc) 119 120 def set_read_eof(self) -> None: 121 self._q.put(self.EOF) 122 123 def wait_for_drain(self, timeout=None) -> None: 124 """Wait for the queue to drain (be fully consumed). 125 126 Args: 127 timeout: The maximum time (in seconds) to wait, or wait forever 128 if None. 129 130 Raises: 131 TimeoutError: If timeout is given and has elapsed. 132 """ 133 # It would be great to use Queue.join() here, but that requires the 134 # consumer to call Queue.task_done(), and we can't do that because 135 # the consumer of read() doesn't know anything about it. 136 # Instead, we poll. ¯\_(ツ)_/¯ 137 start_time = time.time() 138 while not self._q.empty(): 139 if timeout is not None: 140 elapsed = time.time() - start_time 141 if elapsed > timeout: 142 raise TimeoutError(f"Queue not empty after {elapsed} sec") 143 time.sleep(0.1) 144 145 146class QueueFileTest(unittest.TestCase): 147 """Test the QueueFile class""" 148 149 def test_read_data(self) -> None: 150 file = QueueFile() 151 file.put_read_data(b'hello') 152 self.assertEqual(file.read(5), b'hello') 153 154 def test_read_data_multi_read(self) -> None: 155 file = QueueFile() 156 file.put_read_data(b'helloworld') 157 self.assertEqual(file.read(5), b'hello') 158 self.assertEqual(file.read(5), b'world') 159 160 def test_read_data_multi_put(self) -> None: 161 file = QueueFile() 162 file.put_read_data(b'hello') 163 file.put_read_data(b'world') 164 self.assertEqual(file.read(5), b'hello') 165 self.assertEqual(file.read(5), b'world') 166 167 def test_read_eof(self) -> None: 168 file = QueueFile() 169 file.set_read_eof() 170 result = file.read(5) 171 self.assertEqual(result, b'') 172 173 def test_read_exception(self) -> None: 174 file = QueueFile() 175 message = 'test exception' 176 file.cause_read_exc(ValueError(message)) 177 with self.assertRaisesRegex(ValueError, message): 178 file.read(5) 179 180 def test_wait_for_drain_works(self) -> None: 181 file = QueueFile() 182 file.put_read_data(b'hello') 183 file.read() 184 try: 185 # Timeout is arbitrary; will return immediately. 186 file.wait_for_drain(0.1) 187 except TimeoutError: 188 self.fail("wait_for_drain raised TimeoutError") 189 190 def test_wait_for_drain_raises(self) -> None: 191 file = QueueFile() 192 file.put_read_data(b'hello') 193 # don't read 194 with self.assertRaises(TimeoutError): 195 # Timeout is arbitrary; it will raise no matter what. 196 file.wait_for_drain(0.1) 197 198 199class Sentinel: 200 def __repr__(self): 201 return 'Sentinel' 202 203 204class _QueueReader(stream_readers.CancellableReader): 205 def cancel_read(self) -> None: 206 self._base_obj.close() 207 208 209def on_read_error(exc: Exception) -> None: 210 logger = logging.getLogger('pw_stream.stream_readers') 211 logger.error('data reader encountered an error', exc_info=exc) 212 213 214def _null_data_processor(data): 215 del data 216 217 218def _null_frame_handler(frame): 219 del frame 220 221 222class _ScopedReaderAndExecutor(stream_readers.DataReaderAndExecutor): 223 """""" 224 225 def __enter__(self): 226 self.start() 227 return self 228 229 def __exit__(self, *exc_info): 230 self.stop() 231 232 233# This should take <10ms but we'll wait up to 1000x longer. 234_QUEUE_DRAIN_TIMEOUT = 10.0 235 236 237class DataReaderAndExecutorTest(unittest.TestCase): 238 """Tests the DataReaderAndExecutor class.""" 239 240 # NOTE: There is no test here for stream EOF because Serial.read() 241 # can return an empty result if configured with timeout != None. 242 # The reader thread will continue in this case. 243 244 def test_clean_close_after_stream_close(self) -> None: 245 """Assert RpcClient closes cleanly when stream closes.""" 246 # See b/293595266. 247 file = QueueFile() 248 249 with self.assert_no_stream_stream_readers_error_logs(): 250 with file: 251 with _ScopedReaderAndExecutor( 252 reader=_QueueReader(file), 253 on_read_error=on_read_error, 254 data_processor=_null_data_processor, 255 frame_handler=_null_frame_handler, 256 ): 257 # We want to make sure the reader thread is blocked on 258 # read() and doesn't exit immediately. 259 file.put_read_data(b'') 260 file.wait_for_drain(_QUEUE_DRAIN_TIMEOUT) 261 262 # RpcClient.__exit__ calls stop() on the reader thread, but 263 # it is blocked on file.read(). 264 265 # QueueFile.close() is called, triggering an exception in the 266 # blocking read() (by implementation choice). The reader should 267 # handle it by *not* logging it and exiting immediately. 268 269 self.assert_no_background_threads_running() 270 271 def test_device_handles_read_exception(self) -> None: 272 """Assert RpcClient closes cleanly when read raises an exception.""" 273 # See b/293595266. 274 file = QueueFile() 275 276 logger = logging.getLogger('pw_stream.stream_readers') 277 test_exc = Exception('boom') 278 with self.assertLogs(logger, level=logging.ERROR) as ctx: 279 with _ScopedReaderAndExecutor( 280 reader=_QueueReader(file), 281 on_read_error=on_read_error, 282 data_processor=_null_data_processor, 283 frame_handler=_null_frame_handler, 284 ): 285 # Cause read() to raise an exception. The reader should 286 # handle it by logging it and exiting immediately. 287 file.cause_read_exc(test_exc) 288 file.wait_for_drain(_QUEUE_DRAIN_TIMEOUT) 289 290 # Assert one exception was raised 291 self.assertEqual(len(ctx.records), 1) 292 rec = ctx.records[0] 293 self.assertIsNotNone(rec.exc_info) 294 assert rec.exc_info is not None # for mypy 295 self.assertEqual(rec.exc_info[1], test_exc) 296 297 self.assert_no_background_threads_running() 298 299 @contextmanager 300 def assert_no_stream_stream_readers_error_logs(self): 301 logger = logging.getLogger('pw_stream.stream_readers') 302 sentinel = Sentinel() 303 with self.assertLogs(logger, level=logging.ERROR) as ctx: 304 # TODO: b/294861320 - use assertNoLogs() in Python 3.10+ 305 # We actually want to assert there are no errors, but 306 # TestCase.assertNoLogs() is not available until Python 3.10. 307 # So we log one error to keep the test from failing and manually 308 # inspect the list of captured records. 309 logger.error(sentinel) 310 311 yield ctx 312 313 self.assertEqual([record.msg for record in ctx.records], [sentinel]) 314 315 def assert_no_background_threads_running(self): 316 self.assertEqual(threading.enumerate(), [threading.current_thread()]) 317 318 319if __name__ == '__main__': 320 unittest.main() 321