xref: /aosp_15_r20/external/pigweed/pw_build/py/pw_build/generated_tests.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
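"""Tools for generating Pigweed tests that execute in several languages.

A single sequence of test cases is turned into Python unittest tests, C++ tests,
and JavaScript/TypeScript tests.
"""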
15
16import argparse
17from dataclasses import dataclass
18from datetime import datetime
19from collections import defaultdict
20import unittest
21
22from typing import (
23    Any,
24    Callable,
25    Generic,
26    Iterable,
27    Iterator,
28    Sequence,
29    TextIO,
30    TypeVar,
31    Union,
32)
33
# Read the clock once so the copyright year and the "Generated at" timestamp in
# the header always agree. This is evaluated at import time, so every file
# generated by one process carries the same timestamp.
_GENERATED_AT = datetime.now()

# License and "do not edit" banner shared by every generated file.
_COPYRIGHT = f"""\
// Copyright {_GENERATED_AT.year} The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

// AUTOGENERATED - DO NOT EDIT
//
// Generated at {_GENERATED_AT.isoformat()}
"""

# Header for generated C++ files; disables clang-format for the generated code.
_HEADER_CPP = (
    _COPYRIGHT
    + """\
// clang-format off
"""
)

# Header for generated JS/TS files; declares the eslint environment.
_HEADER_JS = (
    _COPYRIGHT
    + """\
/* eslint-env browser, jasmine */
"""
)
67
68
class Error(Exception):
    """Raised when test cases are invalid or generated test names collide."""
71
72
# Type of the caller-defined test case objects that tests are generated from.
T = TypeVar('T')


@dataclass
class Context(Generic[T]):
    """Describes one test case to a test generator function.

    Attributes:
      group: Name of the group that the test case belongs to.
      count: 1-based index of the test case within its group.
      total: Number of test cases in the group.
      test_case: The test case object itself.
    """

    group: str
    count: int
    total: int
    test_case: T

    @staticmethod
    def _sanitize(text: str, replacement: str) -> str:
        """Returns text with each non-alphanumeric char replaced."""
        return ''.join(ch if ch.isalnum() else replacement for ch in text)

    def _numbered(self, base: str, separator: str) -> str:
        """Adds a count suffix to base when the group has several cases."""
        if self.total <= 1:
            return base
        return f'{base}{separator}{self.count}'

    def cc_name(self) -> str:
        """Returns a name for the C++ test of this case.

        The space- or hyphen-separated words of the group name are capitalized
        (first letter upper case, the rest lower case) and concatenated. Any
        remaining non-alphanumeric characters become underscores. A
        ``_<count>`` suffix is appended when the group has more than one case.
        """
        words = self.group.replace('-', ' ').split(' ')
        camel_case = ''.join(map(str.capitalize, words))
        return self._numbered(self._sanitize(camel_case, '_'), '_')

    def py_name(self) -> str:
        """Returns the Python test method name for this case.

        The name is ``test_`` followed by the lower-cased group name, with
        non-alphanumeric characters replaced by underscores. A ``_<count>``
        suffix is added when the group has more than one case.
        """
        base = 'test_' + self._sanitize(self.group.lower(), '_')
        return self._numbered(base, '_')

    def ts_name(self) -> str:
        """Returns the JS/TS test description for this case.

        The description is the lower-cased group name with non-alphanumeric
        characters replaced by spaces. A `` <count>`` suffix is added when the
        group has more than one case.
        """
        description = self._sanitize(self.group.lower(), ' ')
        return self._numbered(description, ' ')
101
102
# Test cases are specified as a sequence of strings or test case instances. The
# strings are used to separate the tests into named groups: each string starts
# a group, and every test case after it (up to the next string) belongs to that
# group. The sequence must begin with a group name, and repeating a group name
# appends further test cases to the existing group. For example:
#
#   STR_SPLIT_TEST_CASES = (
#     'Empty input',
#     MyTestCase('', '', []),
#     MyTestCase('', 'foo', []),
#     'Split on single character',
#     MyTestCase('abcde', 'c', ['ab', 'de']),
#     ...
#   )
#
GroupOrTest = Union[str, T]

# Python tests are generated by a function that returns a function usable as a
# unittest.TestCase method. The returned function is renamed after the test
# case and installed as a method on a generated TestCase subclass.
PyTest = Callable[[unittest.TestCase], None]
PyTestGenerator = Callable[[Context[T]], PyTest]

# C++ tests are generated with a function that returns or yields lines of C++
# code for the given test case.
CcTestGenerator = Callable[[Context[T]], Iterable[str]]

# JS/TypeScript tests are generated the same way as C++ tests: a function that
# returns or yields lines of test code for the given test case.
JsTestGenerator = Callable[[Context[T]], Iterable[str]]
127
128
class TestGenerator(Generic[T]):
    """Generates tests for multiple languages from a series of test cases.

    Test cases are grouped by the group-name strings that precede them. Each
    output method calls a caller-supplied generator function once per test
    case, passing a Context that describes the case.
    """

    def __init__(self, test_cases: Sequence[GroupOrTest[T]]):
        """Groups the test cases under the group-name strings preceding them.

        Args:
          test_cases: Group-name strings interleaved with test cases. The first
              item must be a group name. Repeating a group name appends to the
              existing group.

        Raises:
          Error: If there are fewer than two items, the first item is not a
              string, no test cases follow the group names, or a group that
              contains test cases has an empty name.
        """
        # Group name -> its test cases, both kept in order of appearance.
        self._cases: dict[str, list[T]] = defaultdict(list)
        group = ''

        if len(test_cases) < 2:
            raise Error('At least one test case must be provided')

        if not isinstance(test_cases[0], str):
            raise Error(
                'The first item in the test cases must be a group name string'
            )

        for case in test_cases:
            if isinstance(case, str):
                group = case
            else:
                self._cases[group].append(case)

        # A sequence made only of group names (e.g. ['A', 'B']) passes the
        # length check above but contains no test cases; reject it too.
        if not self._cases:
            raise Error('At least one test case must be provided')

        if '' in self._cases:
            raise Error('Empty test group names are not permitted')

    def _test_contexts(self) -> Iterator[Context[T]]:
        """Yields a Context per test case, numbered from 1 within its group."""
        for group, test_list in self._cases.items():
            for i, test_case in enumerate(test_list, 1):
                yield Context(group, i, len(test_list), test_case)

    def _generate_python_tests(self, define_py_test: PyTestGenerator):
        """Returns a dict mapping each Python test method name to its function.

        Raises:
          Error: If two test cases produce the same Python method name.
        """
        tests: dict[str, Callable[[Any], None]] = {}

        for ctx in self._test_contexts():
            test = define_py_test(ctx)
            # Name the function after the attribute it will be stored under.
            test.__name__ = ctx.py_name()

            if test.__name__ in tests:
                raise Error(f'Multiple Python tests are named {test.__name__}!')

            tests[test.__name__] = test

        return tests

    def python_tests(self, name: str, define_py_test: PyTestGenerator) -> type:
        """Returns a Python unittest.TestCase class with tests for each case.

        Args:
          name: Name of the generated unittest.TestCase subclass.
          define_py_test: Called with each Context; returns the test method.

        Raises:
          Error: If two test cases produce the same Python method name.
        """
        return type(
            name,
            (unittest.TestCase,),
            self._generate_python_tests(define_py_test),
        )

    @staticmethod
    def _write_lines(output: TextIO, lines: Iterable[str]) -> None:
        """Writes each line to output, followed by a newline."""
        for line in lines:
            output.write(line)
            output.write('\n')

    def _generate_cc_tests(
        self, define_cpp_test: CcTestGenerator, header: str, footer: str
    ) -> Iterator[str]:
        """Yields the lines of the C++ test file."""
        yield _HEADER_CPP
        yield header

        for ctx in self._test_contexts():
            yield from define_cpp_test(ctx)
            yield ''  # Blank line between test cases.

        yield footer

    def cc_tests(
        self,
        output: TextIO,
        define_cpp_test: CcTestGenerator,
        header: str,
        footer: str,
    ):
        """Writes C++ unit tests for each test case to the given file.

        Args:
          output: Text stream that receives the generated file.
          define_cpp_test: Called with each Context; yields that case's lines.
          header: Text written after the license banner, before the tests.
          footer: Text written after all tests.
        """
        self._write_lines(
            output, self._generate_cc_tests(define_cpp_test, header, footer)
        )

    def _generate_ts_tests(
        self, define_ts_test: JsTestGenerator, header: str, footer: str
    ) -> Iterator[str]:
        """Yields the lines of the JS/TS test file.

        Unlike the C++ output, no blank line is inserted between test cases.
        """
        yield _HEADER_JS
        yield header

        for ctx in self._test_contexts():
            yield from define_ts_test(ctx)
        yield footer

    def ts_tests(
        self,
        output: TextIO,
        define_js_test: JsTestGenerator,
        header: str,
        footer: str,
    ):
        """Writes JS unit tests for each test case to the given file.

        Args:
          output: Text stream that receives the generated file.
          define_js_test: Called with each Context; yields that case's lines.
          header: Text written after the license banner, before the tests.
          footer: Text written after all tests.
        """
        self._write_lines(
            output, self._generate_ts_tests(define_js_test, header, footer)
        )
225
226
227def _to_chars(data: bytes) -> Iterator[str]:
228    for i, byte in enumerate(data):
229        try:
230            char = data[i : i + 1].decode()
231            yield char if char.isprintable() else fr'\x{byte:02x}'
232        except UnicodeDecodeError:
233            yield fr'\x{byte:02x}'
234
235
236def cc_string(data: str | bytes) -> str:
237    """Returns a C++ string literal version of a byte string or UTF-8 string."""
238    if isinstance(data, str):
239        data = data.encode()
240
241    return '"' + ''.join(_to_chars(data)) + '"'
242
243
244def parse_test_generation_args() -> argparse.Namespace:
245    parser = argparse.ArgumentParser(description='Generate unit test files')
246    parser.add_argument(
247        '--generate-cc-test',
248        type=argparse.FileType('w'),
249        help='Generate the C++ test file',
250    )
251    parser.add_argument(
252        '--generate-ts-test',
253        type=argparse.FileType('w'),
254        help='Generate the JS test file',
255    )
256    return parser.parse_known_args()[0]
257