xref: /aosp_15_r20/external/pigweed/pw_tls_client/py/pw_tls_client/generate_test_data.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
"""Generate test data

Generate data needed for unit tests, i.e. certificates and private keys.
"""
18
import argparse
import subprocess
import sys
from datetime import datetime, timedelta, timezone

from cryptography import x509
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
29
# Preamble that main() writes at the top of the generated C++ header, before
# the generated definitions. It carries the license text, `#pragma once`, and
# the pw_bytes include (presumably what provides pw::ConstByteSpan and
# pw::as_bytes for the generated declarations -- confirm). Emitted verbatim,
# so keep it byte-exact.
CERTS_AND_KEYS_HEADER = """// Copyright 2021 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy
// of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

#pragma once

#include "pw_bytes/span.h"

"""
49
50
class Subject:
    """A named X.509 subject with its own freshly generated RSA key pair.

    Bundles the subject name, the private key, and the extensions an issuer
    should add when it issues this subject's certificate (see ``CA.sign``).
    """

    def __init__(
        self,
        name: str,
        extensions: list[tuple[x509.ExtensionType, bool]],
        *,
        key_size: int = 2048,
    ):
        """Creates a subject and generates its RSA private key.

        Args:
            name: Value of the ORGANIZATION_NAME attribute. The country,
                state, locality and common name attributes are fixed.
            extensions: ``(extension, critical)`` pairs that the issuer adds
                to this subject's certificate. The list is copied.
            key_size: RSA modulus size in bits (public exponent is 65537).
        """
        self._subject_name = x509.Name(
            [
                x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
                x509.NameAttribute(
                    NameOID.STATE_OR_PROVINCE_NAME, "California"
                ),
                x509.NameAttribute(NameOID.LOCALITY_NAME, "Mountain View"),
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, name),
                x509.NameAttribute(NameOID.COMMON_NAME, "Google-Pigweed"),
            ]
        )
        self._private_key = rsa.generate_private_key(
            public_exponent=65537, key_size=key_size
        )
        # Copy so that later mutation of the caller's list cannot change what
        # an issuer puts into this subject's certificate.
        self._extensions = list(extensions)

    def subject_name(self) -> x509.Name:
        """Returns the subject name"""
        return self._subject_name

    def public_key(self) -> rsa.RSAPublicKey:
        """Returns the public key of this subject"""
        return self._private_key.public_key()

    def private_key(self) -> rsa.RSAPrivateKey:
        """Returns the private key of this subject"""
        return self._private_key

    def extensions(self) -> list[tuple[x509.ExtensionType, bool]]:
        """Returns the requested extensions for issuer"""
        return self._extensions
89
90
class CA(Subject):
    """A certificate authority (root or intermediate) that issues certs.

    The certificate issued for a CA carries a critical
    BasicConstraints(ca=True) with no path length limit and a critical
    KeyUsage that permits only certificate signing.
    """

    def __init__(self, *args, **kwargs):
        ca_key_usage = x509.KeyUsage(
            digital_signature=False,
            content_commitment=False,
            key_encipherment=False,
            data_encipherment=False,
            key_agreement=False,
            key_cert_sign=True,
            crl_sign=False,
            encipher_only=False,
            decipher_only=False,
        )
        ca_extensions = [
            (x509.BasicConstraints(ca=True, path_length=None), True),
            (ca_key_usage, True),
        ]
        super().__init__(*args, extensions=ca_extensions, **kwargs)

    def sign(
        self, subject: Subject, not_before: datetime, not_after: datetime
    ) -> x509.Certificate:
        """Issues ``subject``'s certificate, signed by this CA with SHA-256.

        The subject name and public key come from ``subject``, the issuer
        name is this CA's subject name, the serial number is random, and
        ``subject``'s requested extensions are added in order.
        """
        cert_builder = (
            x509.CertificateBuilder()
            .subject_name(subject.subject_name())
            .issuer_name(self._subject_name)
            .public_key(subject.public_key())
            .not_valid_before(not_before)
            .not_valid_after(not_after)
            .serial_number(x509.random_serial_number())
        )
        for ext, is_critical in subject.extensions():
            cert_builder = cert_builder.add_extension(ext, is_critical)
        return cert_builder.sign(self._private_key, hashes.SHA256())

    def self_sign(
        self, not_before: datetime, not_after: datetime
    ) -> x509.Certificate:
        """Issues this CA's own certificate, signed with its own key."""
        return self.sign(self, not_before, not_after)
149
150
class Server(Subject):
    """An end-entity (leaf) server subject.

    Requests three critical extensions for its certificate: non-CA
    BasicConstraints, a KeyUsage allowing only digital signatures, and an
    ExtendedKeyUsage of serverAuth.
    """

    def __init__(self, *args, **kwargs):
        leaf_key_usage = x509.KeyUsage(
            digital_signature=True,
            content_commitment=False,
            key_encipherment=False,
            data_encipherment=False,
            key_agreement=False,
            key_cert_sign=False,
            crl_sign=False,
            encipher_only=False,
            decipher_only=False,
        )
        server_auth_only = x509.ExtendedKeyUsage(
            [x509.ExtendedKeyUsageOID.SERVER_AUTH]
        )
        leaf_extensions = [
            (x509.BasicConstraints(ca=False, path_length=None), True),
            (leaf_key_usage, True),
            (server_auth_only, True),
        ]
        super().__init__(*args, extensions=leaf_extensions, **kwargs)
177
178
def c_escaped_string(data: bytes) -> str:
    r"""Generates a C byte string representation for a byte array.

    Every byte is written as a two-digit ``\xNN`` escape. The escaped string
    and the byte count are returned as a brace-enclosed pair. For example,
    given the byte sequence [0x12, 0x34, 0x56], the function generates:

            {"\x12\x34\x56", 3}

    Each escape is immediately followed by another backslash or the closing
    quote, so C's greedy hex-escape parsing never absorbs the next
    character. The explicit length lets the consumer ignore the literal's
    implicit terminator and any embedded 0x00 bytes.
    """
    body = ''.join(f'\\x{byte:02x}' for byte in data)
    return f'{{"{body}", {len(data)}}}'
189
190
def byte_array_declaration(data: bytes, name: str) -> str:
    """Returns a C++ statement declaring ``name`` as a ConstByteSpan.

    The statement has the shape::

        [[maybe_unused]] const pw::ConstByteSpan <name> =
            pw::as_bytes(pw::span<c_escaped_string(data)>);

    (emitted on a single line).
    """
    initializer = 'pw::as_bytes(pw::span' + c_escaped_string(data) + ')'
    pieces = ['[[maybe_unused]] const pw::ConstByteSpan', name, '=', initializer]
    return ' '.join(pieces) + ';'
196
197
class Codegen:
    """Base class for objects that render themselves as C++ source.

    Subclasses override ``generate_code``; this base version emits nothing.
    """

    def generate_code(self) -> str:  # pylint: disable=no-self-use
        """Returns the C++ code for this object (empty for the base class)."""
        generated = ''
        return generated
204
205
class PrivateKeyGen(Codegen):
    """Emits an RSA private key as a C++ ConstByteSpan declaration."""

    def __init__(self, key: rsa.RSAPrivateKey, name: str):
        self._key = key
        self._name = name

    def generate_code(self) -> str:
        """Returns the declaration of the DER-encoded, unencrypted key.

        The key is serialized in cryptography's TraditionalOpenSSL private
        key format.
        """
        der_bytes = self._key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption(),
        )
        return byte_array_declaration(der_bytes, self._name)
223
224
class CertificateGen(Codegen):
    """Emits one X.509 certificate as a C++ ConstByteSpan declaration."""

    def __init__(self, cert: x509.Certificate, name: str):
        self._cert = cert
        self._name = name

    def generate_code(self) -> str:
        """Returns the declaration of the DER-encoded certificate."""
        der_bytes = self._cert.public_bytes(serialization.Encoding.DER)
        return byte_array_declaration(der_bytes, self._name)
237
238
def generate_test_data() -> str:
    """Generates C++ declarations for the test certificates and key.

    Creates, in this order:
      * kRootACert  - self-signed certificate of root CA "root-A".
      * kSubCACert  - certificate of intermediate CA "sub", issued by root-A.
      * kServerCert - certificate of server "server", issued by sub.
      * kServerKey  - the server's private key.
      * kRootBCert  - self-signed certificate of a second root CA "root-B",
        which issues nothing else here.

    Every certificate is valid for 365 days starting one day before the
    current UTC time.

    Returns:
        The declarations, each followed by a blank line, wrapped in an
        anonymous C++ namespace.
    """
    subjects: list[Codegen] = []

    # Working valid period. Start from yesterday to make sure "now" is inside
    # it. Use a timezone-aware UTC time: datetime.utcnow() is deprecated since
    # Python 3.12, and cryptography converts aware datetimes to UTC.
    not_before = datetime.now(timezone.utc) - timedelta(days=1)
    # Valid for 1 year.
    not_after = not_before + timedelta(days=365)

    # Root-A CA, self-signed.
    root_a = CA("root-A")
    subjects.append(
        CertificateGen(root_a.self_sign(not_before, not_after), "kRootACert")
    )

    # Sub CA certificate signed by root-A.
    sub = CA("sub")
    subjects.append(
        CertificateGen(root_a.sign(sub, not_before, not_after), "kSubCACert")
    )

    # Valid server certificate signed by sub, plus the server's private key.
    server = Server("server")
    subjects.append(
        CertificateGen(sub.sign(server, not_before, not_after), "kServerCert")
    )
    subjects.append(PrivateKeyGen(server.private_key(), "kServerKey"))

    # A second, independent self-signed root.
    root_b = CA("root-B")
    subjects.append(
        CertificateGen(root_b.self_sign(not_before, not_after), "kRootBCert")
    )

    # join instead of repeated += concatenation.
    body = ''.join(subject.generate_code() + '\n\n' for subject in subjects)
    return 'namespace {\n\n' + body + '}\n'
279
280
def clang_format(file):
    """Reformats ``file`` in place by running ``clang-format -i``.

    Raises:
        subprocess.CalledProcessError: clang-format exited with a non-zero
            status.
        OSError: clang-format could not be launched (e.g. not on PATH).
    """
    command = ["clang-format", "-i", file]
    subprocess.run(command, check=True)
290
291
def parse_args():
    """Parses sys.argv; the only argument is the output header path."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "certs_and_keys_header",
        help="output header file for test certificates and keys",
    )
    parsed = arg_parser.parse_args()
    return parsed
300
301
def main() -> int:
    """Writes the generated test-data header and clang-formats it.

    Reads the output path from the command line, writes
    CERTS_AND_KEYS_HEADER followed by the generated declarations, then runs
    clang-format on the file in place.

    Returns:
        0 on success; failures propagate as exceptions.
    """
    args = parse_args()

    certs_and_keys = generate_test_data()

    # Explicit encoding: the default text encoding is locale-dependent
    # (PEP 597).
    with open(args.certs_and_keys_header, 'w', encoding='utf-8') as header:
        header.write(CERTS_AND_KEYS_HEADER)
        header.write(certs_and_keys)

    clang_format(args.certs_and_keys_header)
    return 0
314
315
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    sys.exit(exit_status)
318