#!/usr/bin/env python3
# Copyright 2023 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Facilitates automating unit tests on devices with serial ports.

This library assumes that the on-device test runner emits the test results
as plain-text over a serial port, and that tests are triggered by a pre-defined
input (``DEFAULT_TEST_START_CHARACTER``) over the same serial port that results
are emitted from.
"""

import abc
import logging
from pathlib import Path

import serial  # type: ignore


_LOG = logging.getLogger("serial_test_runner")

# Verification of test pass/failure depends on these strings. If the formatting
# or output of the simple_printing_event_handler changes, this may need to be
# updated.
#
# NOTE: These are matched byte-for-byte against device output, so internal
# spacing matters. The failure marker has two spaces on each side of "FAILED"
# (matching the GoogleTest-style "[  FAILED  ]" banner); a single-space variant
# would never match and failing suites would be reported as passing.
_TESTS_STARTING_STRING = b'[==========] Running all tests.'
_TESTS_DONE_STRING = b'[==========] Done running all tests.'
_TEST_FAILURE_STRING = b'[  FAILED  ]'

# Character used to trigger test start.
DEFAULT_TEST_START_CHARACTER = ' '.encode('utf-8')


class FlashingFailure(Exception):
    """A simple exception to be raised when flashing fails."""


class TestingFailure(Exception):
    """A simple exception to be raised when a testing step fails."""


class DeviceNotFound(Exception):
    """A simple exception to be raised when unable to connect to a device."""


class SerialTestingDevice(abc.ABC):
    """A device that supports automated testing via parsing serial output."""

    @abc.abstractmethod
    def load_binary(self, binary: Path) -> bool:
        """Flashes the specified binary to this device.

        Raises:
            DeviceNotFound: This device is no longer available.
            FlashingFailure: The binary could not be flashed.

        Returns:
            True if the binary was loaded successfully.
        """

    @abc.abstractmethod
    def serial_port(self) -> str:
        """Returns the name of the com port this device is enumerated on.

        Raises:
            DeviceNotFound: This device is no longer available.
        """

    @abc.abstractmethod
    def baud_rate(self) -> int:
        """Returns the baud rate to use when connecting to this device.

        Raises:
            DeviceNotFound: This device is no longer available.
        """


def _log_subprocess_output(
    level: int, output: bytes, logger: logging.Logger
) -> None:
    """Logs subprocess output line-by-line.

    Args:
        level: Logging level (e.g. ``logging.INFO``) to emit each line at.
        output: Raw captured output. It is decoded as UTF-8 with invalid bytes
            replaced, so garbled serial data never makes logging fail.
        logger: Logger to emit the lines to.
    """
    for line in output.decode('utf-8', errors='replace').splitlines():
        logger.log(level, line)


def trigger_test_run(
    port: str,
    baud_rate: int,
    test_timeout: float,
    trigger_data: bytes = DEFAULT_TEST_START_CHARACTER,
) -> bytes:
    """Triggers a test run, and returns captured test results.

    Opens ``port``, discards stale input, writes ``trigger_data`` to start the
    on-device test runner, then captures output until the device stops
    producing data. Until the "done" banner is seen, each read waits up to
    ``test_timeout`` seconds; afterwards a very short timeout is used because
    only the trailing summary lines remain.

    Args:
        port: Name of the serial port the device is enumerated on.
        baud_rate: Baud rate used to talk to the device.
        test_timeout: Seconds to wait on each read before the device is
            considered to have stopped producing output.
        trigger_data: Bytes written to the device to start the test run.

    Returns:
        The captured output with carriage returns removed. If a test-start
        banner was captured, everything before the most recent one is dropped.

    Raises:
        TestingFailure: The port is not open after construction, or the device
            produced no output within ``test_timeout``. Other errors raised
            while opening the port propagate unchanged.
    """
    serial_data = bytearray()
    # The context manager closes the port on every exit path, including when
    # TestingFailure is raised below.
    with serial.Serial(
        baudrate=baud_rate, port=port, timeout=test_timeout
    ) as device:
        if not device.is_open:
            raise TestingFailure('Failed to open device')

        # Flush input buffer and trigger the test start.
        device.reset_input_buffer()
        device.write(trigger_data)

        # Block and wait for the first byte.
        serial_data += device.read()
        if not serial_data:
            raise TestingFailure('Device not producing output')

        # Read with a reasonable timeout until we stop getting characters.
        tests_done = False
        while True:
            bytes_read = device.readline()
            if not bytes_read:
                break
            serial_data += bytes_read
            if tests_done:
                continue
            # Only scan the newly read bytes, plus enough preceding bytes to
            # catch a banner split across reads, instead of rescanning the
            # entire (growing) buffer on every line.
            overlap = len(bytes_read) + len(_TESTS_DONE_STRING) - 1
            search_start = max(0, len(serial_data) - overlap)
            if serial_data.find(_TESTS_DONE_STRING, search_start) != -1:
                tests_done = True
                # Set to much more aggressive timeout since the last one or
                # two lines should print out immediately. (one line if all
                # fails or all passes, two lines if mixed.)
                device.timeout = 0.01

    # Remove carriage returns.
    serial_data = serial_data.replace(b'\r', b'')

    # Try to trim captured results to only contain most recent test run.
    test_start_index = serial_data.rfind(_TESTS_STARTING_STRING)
    return (
        serial_data
        if test_start_index == -1
        else serial_data[test_start_index:]
    )


def handle_test_results(
    test_output: bytes, logger: logging.Logger = _LOG
) -> None:
    """Parses test output to determine whether tests passed or failed.

    Args:
        test_output: Captured output of a test run.
        logger: Logger used to report results. On failure the full output is
            logged at INFO level; on success it is logged at DEBUG level.

    Raises:
        TestingFailure if any tests fail or if test results are incomplete.
    """
    if _TESTS_STARTING_STRING not in test_output:
        # Log whatever was captured so it is possible to diagnose why the test
        # runner never reported starting.
        _log_subprocess_output(logging.INFO, test_output, logger)
        raise TestingFailure('Failed to find test start')

    if _TESTS_DONE_STRING not in test_output:
        _log_subprocess_output(logging.INFO, test_output, logger)
        raise TestingFailure('Tests did not complete')

    if _TEST_FAILURE_STRING in test_output:
        _log_subprocess_output(logging.INFO, test_output, logger)
        raise TestingFailure('Test suite had one or more failures')

    _log_subprocess_output(logging.DEBUG, test_output, logger)

    logger.info('Test passed!')


def run_device_test(
    device: SerialTestingDevice,
    binary: Path,
    test_timeout: float,
    logger: logging.Logger = _LOG,
) -> bool:
    """Runs tests on a device.

    When a unit test run fails, results are logged as an error.

    Args:
        device: The device to run tests on.
        binary: The binary containing tests to flash on the device.
        test_timeout: If the device stops producing output longer than this
            timeout, the test is considered stuck and is aborted.
        logger: Logger used for progress and result messages.

    Returns:
        True if all tests passed. False if ``device.load_binary`` reported
        failure or any testing step raised TestingFailure.

    Raises:
        Exceptions other than TestingFailure (e.g. FlashingFailure or
        DeviceNotFound from the device methods) propagate to the caller.
    """
    logger.info('Flashing binary to device')
    if not device.load_binary(binary):
        # Running tests now would exercise whatever image is already on the
        # device, so a reported flashing failure must fail the run.
        logger.error('Failed to flash binary to device')
        return False

    try:
        logger.info('Running test')
        test_output = trigger_test_run(
            device.serial_port(), device.baud_rate(), test_timeout
        )
        # Always validate: an empty capture must be reported as a failure
        # (missing test start), never silently counted as a pass.
        handle_test_results(test_output, logger)
    except TestingFailure as err:
        logger.error(err)
        return False

    return True