1# Copyright 2020 The Pigweed Authors 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); you may not 4# use this file except in compliance with the License. You may obtain a copy of 5# the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12# License for the specific language governing permissions and limitations under 13# the License. 14"""Tools for generating Pigweed tests that execute in C++ and Python.""" 15 16import argparse 17from dataclasses import dataclass 18from datetime import datetime 19from collections import defaultdict 20import unittest 21 22from typing import ( 23 Any, 24 Callable, 25 Generic, 26 Iterable, 27 Iterator, 28 Sequence, 29 TextIO, 30 TypeVar, 31 Union, 32) 33 34_COPYRIGHT = f"""\ 35// Copyright {datetime.now().year} The Pigweed Authors 36// 37// Licensed under the Apache License, Version 2.0 (the "License"); you may not 38// use this file except in compliance with the License. You may obtain a copy of 39// the License at 40// 41// https://www.apache.org/licenses/LICENSE-2.0 42// 43// Unless required by applicable law or agreed to in writing, software 44// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 45// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 46// License for the specific language governing permissions and limitations under 47// the License. 
48 49// AUTOGENERATED - DO NOT EDIT 50// 51// Generated at {datetime.now().isoformat()} 52""" 53 54_HEADER_CPP = ( 55 _COPYRIGHT 56 + """\ 57// clang-format off 58""" 59) 60 61_HEADER_JS = ( 62 _COPYRIGHT 63 + """\ 64/* eslint-env browser, jasmine */ 65""" 66) 67 68 69class Error(Exception): 70 """Something went wrong when generating tests.""" 71 72 73T = TypeVar('T') 74 75 76@dataclass 77class Context(Generic[T]): 78 """Info passed into test generator functions for each test case.""" 79 80 group: str 81 count: int 82 total: int 83 test_case: T 84 85 def cc_name(self) -> str: 86 name = ''.join( 87 w.capitalize() for w in self.group.replace('-', ' ').split(' ') 88 ) 89 name = ''.join(c if c.isalnum() else '_' for c in name) 90 return f'{name}_{self.count}' if self.total > 1 else name 91 92 def py_name(self) -> str: 93 name = 'test_' + ''.join( 94 c if c.isalnum() else '_' for c in self.group.lower() 95 ) 96 return f'{name}_{self.count}' if self.total > 1 else name 97 98 def ts_name(self) -> str: 99 name = ''.join(c if c.isalnum() else ' ' for c in self.group.lower()) 100 return f'{name} {self.count}' if self.total > 1 else name 101 102 103# Test cases are specified as a sequence of strings or test case instances. The 104# strings are used to separate the tests into named groups. For example: 105# 106# STR_SPLIT_TEST_CASES = ( 107# 'Empty input', 108# MyTestCase('', '', []), 109# MyTestCase('', 'foo', []), 110# 'Split on single character', 111# MyTestCase('abcde', 'c', ['ab', 'de']), 112# ... 113# ) 114# 115GroupOrTest = Union[str, T] 116 117# Python tests are generated by a function that returns a function usable as a 118# unittest.TestCase method. 119PyTest = Callable[[unittest.TestCase], None] 120PyTestGenerator = Callable[[Context[T]], PyTest] 121 122# C++ tests are generated with a function that returns or yields lines of C++ 123# code for the given test case. 
CcTestGenerator = Callable[[Context[T]], Iterable[str]]

# JS/TS tests are generated the same way: a function that returns or yields
# lines of code for the given test case.
JsTestGenerator = Callable[[Context[T]], Iterable[str]]


class TestGenerator(Generic[T]):
    """Generates tests for multiple languages from a series of test cases."""

    def __init__(self, test_cases: Sequence[GroupOrTest[T]]):
        """Groups the test cases under the group name string preceding them.

        Args:
            test_cases: Group name strings interleaved with test case objects.
                The first item must be a group name.

        Raises:
            Error: If no test cases are provided, the first item is not a
                group name string, or a group name is empty.
        """
        self._cases: dict[str, list[T]] = defaultdict(list)
        message = ''

        if len(test_cases) < 2:
            raise Error('At least one test case must be provided')

        if not isinstance(test_cases[0], str):
            raise Error(
                'The first item in the test cases must be a group name string'
            )

        for case in test_cases:
            if isinstance(case, str):
                message = case
            else:
                self._cases[message].append(case)

        # A sequence made up only of group names (e.g. ['a', 'b']) passes the
        # length check above but still contains no actual test cases.
        if not self._cases:
            raise Error('At least one test case must be provided')

        if '' in self._cases:
            raise Error('Empty test group names are not permitted')

    def _test_contexts(self) -> Iterator[Context[T]]:
        """Yields a Context for every test case, group by group, in order."""
        for group, test_list in self._cases.items():
            for i, test_case in enumerate(test_list, 1):
                yield Context(group, i, len(test_list), test_case)

    def _generate_python_tests(
        self, define_py_test: PyTestGenerator
    ) -> dict[str, Callable[[Any], None]]:
        """Returns a mapping of unique test method name to test function.

        Raises:
            Error: If two test cases produce the same Python test name.
        """
        tests: dict[str, Callable[[Any], None]] = {}

        for ctx in self._test_contexts():
            test = define_py_test(ctx)
            test.__name__ = ctx.py_name()

            if test.__name__ in tests:
                raise Error(f'Multiple Python tests are named {test.__name__}!')

            tests[test.__name__] = test

        return tests

    def python_tests(self, name: str, define_py_test: PyTestGenerator) -> type:
        """Returns a Python unittest.TestCase class with tests for each case."""
        return type(
            name,
            (unittest.TestCase,),
            self._generate_python_tests(define_py_test),
        )

    def _generate_cc_tests(
        self, define_cpp_test: CcTestGenerator, header: str, footer: str
    ) -> Iterator[str]:
        """Yields the lines of a C++ test file; tests are blank-line separated."""
        yield _HEADER_CPP
        yield header

        for ctx in self._test_contexts():
            yield from define_cpp_test(ctx)
            yield ''

        yield footer

    def cc_tests(
        self,
        output: TextIO,
        define_cpp_test: CcTestGenerator,
        header: str,
        footer: str,
    ):
        """Writes C++ unit tests for each test case to the given file."""
        for line in self._generate_cc_tests(define_cpp_test, header, footer):
            output.write(line)
            output.write('\n')

    def _generate_ts_tests(
        self, define_ts_test: JsTestGenerator, header: str, footer: str
    ) -> Iterator[str]:
        """Yields the lines of a JS/TS test file."""
        yield _HEADER_JS
        yield header

        for ctx in self._test_contexts():
            yield from define_ts_test(ctx)
        yield footer

    def ts_tests(
        self,
        output: TextIO,
        define_js_test: JsTestGenerator,
        header: str,
        footer: str,
    ):
        """Writes JS unit tests for each test case to the given file."""
        for line in self._generate_ts_tests(define_js_test, header, footer):
            output.write(line)
            output.write('\n')


# Characters that a C++ \x escape would absorb if they followed it directly.
_HEX_DIGITS = frozenset('0123456789abcdefABCDEF')


def _to_chars(data: bytes) -> Iterator[str]:
    r"""Yields one C++ string literal fragment per byte of data.

    Printable ASCII characters are emitted literally, except that " and \ are
    backslash-escaped so they cannot terminate or corrupt the literal. Every
    other byte (control characters and bytes >= 0x80) becomes a two-digit \x
    hex escape.

    A C++ hex escape consumes every hex digit that follows it, so "\x01a" is a
    single character with value 0x1a rather than 0x01 followed by 'a'. When a
    literal hex digit directly follows a hex escape, its fragment is prefixed
    with "" to close and reopen the literal ("\x01""a"); the compiler
    concatenates the adjacent string literals.
    """
    after_hex_escape = False

    for byte in data:
        char = chr(byte)

        # Only bytes below 0x80 are complete UTF-8 characters on their own.
        if byte < 0x80 and char.isprintable():
            fragment = '\\' + char if char in '"\\' else char
            if after_hex_escape and char in _HEX_DIGITS:
                fragment = '""' + fragment
            after_hex_escape = False
            yield fragment
        else:
            after_hex_escape = True
            yield f'\\x{byte:02x}'


def cc_string(data: str | bytes) -> str:
    """Returns a C++ string literal version of a byte string or UTF-8 string.

    str input is UTF-8 encoded first. The result may consist of several
    adjacent string literals (e.g. '"\\x01""a"'); they concatenate to exactly
    the bytes of data.
    """
    if isinstance(data, str):
        data = data.encode()

    return '"' + ''.join(_to_chars(data)) + '"'


def parse_test_generation_args() -> argparse.Namespace:
    """Parses the --generate-cc-test and --generate-ts-test options.

    Each option, when given, is opened for writing (argparse.FileType('w')) and
    returned as a file object. Unrecognized arguments are ignored rather than
    rejected.
    """
    parser = argparse.ArgumentParser(description='Generate unit test files')
    parser.add_argument(
        '--generate-cc-test',
        type=argparse.FileType('w'),
        help='Generate the C++ test file',
    )
    parser.add_argument(
        '--generate-ts-test',
        type=argparse.FileType('w'),
        help='Generate the JS test file',
    )
    return parser.parse_known_args()[0]