xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio_tests/tests/unit/_server_ssl_cert_config_test.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2017 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests server certificate rotation.
15
Here we test various aspects of gRPC Python's (and, in some cases, by
extension gRPC Core's) support for server certificate rotation.
18
19* ServerSSLCertReloadTestWithClientAuth: test ability to rotate
20  server's SSL cert for use in future channels with clients while not
21  affecting any existing channel. The server requires client
22  authentication.
23
24* ServerSSLCertReloadTestWithoutClientAuth: like
25  ServerSSLCertReloadTestWithClientAuth except that the server does
26  not authenticate the client.
27
* ServerSSLCertReloadTestCertConfigReuse: tests gRPC Python's ability
  to handle a user's reuse of ServerCertificateConfiguration instances.
30"""
31
32import abc
33import collections
34from concurrent import futures
35import logging
36import os
37import threading
38import unittest
39
40import grpc
41
42from tests.testing import _application_common
43from tests.testing import _server_application
44from tests.testing.proto import services_pb2_grpc
45from tests.unit import resources
46from tests.unit import test_common
47
# Root CA certificates of the two independent certificate hierarchies.
CA_1_PEM = resources.cert_hier_1_root_ca_cert()
CA_2_PEM = resources.cert_hier_2_root_ca_cert()

# Client identities. Every certificate chain below is the leaf certificate
# followed by the intermediate CA certificate of the same hierarchy.
CLIENT_KEY_1_PEM = resources.cert_hier_1_client_1_key()
CLIENT_CERT_CHAIN_1_PEM = (
    resources.cert_hier_1_client_1_cert()
    + resources.cert_hier_1_intermediate_ca_cert()
)

CLIENT_KEY_2_PEM = resources.cert_hier_2_client_1_key()
CLIENT_CERT_CHAIN_2_PEM = (
    resources.cert_hier_2_client_1_cert()
    + resources.cert_hier_2_intermediate_ca_cert()
)

# Server identities, assembled the same way as the client chains.
SERVER_KEY_1_PEM = resources.cert_hier_1_server_1_key()
SERVER_CERT_CHAIN_1_PEM = (
    resources.cert_hier_1_server_1_cert()
    + resources.cert_hier_1_intermediate_ca_cert()
)

SERVER_KEY_2_PEM = resources.cert_hier_2_server_1_key()
SERVER_CERT_CHAIN_2_PEM = (
    resources.cert_hier_2_server_1_cert()
    + resources.cert_hier_2_intermediate_ca_cert()
)

# One recorded CertConfigFetcher invocation: whether it raised, and the
# certificate configuration (or None) it returned. Together with
# CertConfigFetcher this forms a small hand-rolled mock.
Call = collections.namedtuple("Call", "did_raise returned_cert_config")
78
79
def _create_channel(port, credentials):
    """Opens a secure channel to the local test server on `port`."""
    target = f"localhost:{port}"
    return grpc.secure_channel(target, credentials)
82
83
84def _create_client_stub(channel, expect_success):
85    if expect_success:
86        # per Nathaniel: there's some robustness issue if we start
87        # using a channel without waiting for it to be actually ready
88        grpc.channel_ready_future(channel).result(timeout=10)
89    return services_pb2_grpc.FirstServiceStub(channel)
90
91
class CertConfigFetcher(object):
    """Configurable, lock-protected certificate-config fetcher for tests.

    Instances are handed to `grpc.dynamic_ssl_server_credentials` as the
    fetcher callable. Every invocation is recorded as a `Call`, so tests can
    check how often the fetcher ran and what it produced. Depending on the
    last `configure` call, an invocation either raises ValueError or returns
    the configured certificate configuration (which may be None).
    """

    def __init__(self):
        self._lock = threading.Lock()
        # `Call` records appended by `__call__`; replaced (not cleared) by
        # `reset`.
        self._calls = []
        self._should_raise = False
        self._cert_config = None

    def reset(self):
        """Forgets recorded calls and restores the default behavior.

        The default is: do not raise, return None.
        """
        with self._lock:
            self._calls = []
            self._should_raise = False
            self._cert_config = None

    def configure(self, should_raise, cert_config):
        """Sets how subsequent invocations behave.

        Args:
          should_raise: if true, invocations raise ValueError.
          cert_config: value returned by invocations when not raising.

        Setting both at once is a test-authoring error and trips an assert.
        """
        assert not (should_raise and cert_config), (
            "should not specify both should_raise and a cert_config at the same"
            " time"
        )
        with self._lock:
            self._should_raise = should_raise
            self._cert_config = cert_config

    def getCalls(self):
        """Returns a snapshot list of the calls recorded since the last reset.

        A copy is returned so that later invocations cannot change the list
        the caller is holding or iterating. The lock suggests invocations
        may come from gRPC threads other than the test's.
        """
        with self._lock:
            return list(self._calls)

    def __call__(self):
        """Records this invocation, then raises or returns per `configure`."""
        with self._lock:
            if self._should_raise:
                self._calls.append(Call(True, None))
                raise ValueError("just for fun, should not affect the test")
            self._calls.append(Call(False, self._cert_config))
            return self._cert_config
126
127
class _ServerSSLCertReloadTest(unittest.TestCase, metaclass=abc.ABCMeta):
    """Shared fixture and scenario for server certificate rotation tests.

    `setUp` starts a test server whose credentials come from
    `grpc.dynamic_ssl_server_credentials`. The server initially serves key
    and cert chain 1 and trusts CA 2, and uses a `CertConfigFetcher` that
    each phase of `_test` reconfigures. Concrete subclasses decide through
    `require_client_auth` whether the server demands client certificates,
    and expose `_test` as a real test method.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.server = None
        self.port = None

    @abc.abstractmethod
    def require_client_auth(self):
        """Returns whether the server under test requires client auth."""
        raise NotImplementedError()

    def setUp(self):
        self.server = test_common.test_server()
        services_pb2_grpc.add_FirstServiceServicer_to_server(
            _server_application.FirstServiceServicer(), self.server
        )
        initial_cert_config = grpc.ssl_server_certificate_configuration(
            [(SERVER_KEY_1_PEM, SERVER_CERT_CHAIN_1_PEM)],
            root_certificates=CA_2_PEM,
        )
        self.cert_config_fetcher = CertConfigFetcher()
        server_credentials = grpc.dynamic_ssl_server_credentials(
            initial_cert_config,
            self.cert_config_fetcher,
            require_client_authentication=self.require_client_auth(),
        )
        # Port 0: let the server choose; the bound port is kept for clients.
        self.port = self.server.add_secure_port("[::]:0", server_credentials)
        self.server.start()

    def tearDown(self):
        if self.server:
            # No grace period.
            self.server.stop(None)

    def _perform_rpc(self, client_stub, expect_success):
        """Issues one unary-unary RPC and checks the expected outcome.

        Only success versus failure matters. On success the canned response
        is verified; on failure the status code must be UNAVAILABLE or
        UNKNOWN.
        """
        request = _application_common.UNARY_UNARY_REQUEST
        if expect_success:
            response = client_stub.UnUn(request)
            self.assertEqual(response, _application_common.UNARY_UNARY_RESPONSE)
        else:
            with self.assertRaises(grpc.RpcError) as exception_context:
                client_stub.UnUn(request)
            # If TLS 1.2 is used, then the client receives an alert message
            # before the handshake is complete, so the status is UNAVAILABLE. If
            # TLS 1.3 is used, then the client receives the alert message after
            # the handshake is complete, so the TSI handshaker returns the
            # TSI_PROTOCOL_FAILURE result. This result does not have a
            # corresponding status code, so this yields an UNKNOWN status.
            self.assertIn(
                exception_context.exception.code(),
                (grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNKNOWN),
            )

    def _do_one_shot_client_rpc(
        self,
        expect_success,
        root_certificates=None,
        private_key=None,
        certificate_chain=None,
    ):
        """Opens a fresh channel with the given client TLS material.

        Performs one RPC expecting `expect_success`, then closes the channel.
        """
        credentials = grpc.ssl_channel_credentials(
            root_certificates=root_certificates,
            private_key=private_key,
            certificate_chain=certificate_chain,
        )
        with _create_channel(self.port, credentials) as client_channel:
            client_stub = _create_client_stub(client_channel, expect_success)
            self._perform_rpc(client_stub, expect_success)

    def _reconfigure_fetcher(self, should_raise, cert_config):
        """Clears recorded fetcher calls and sets the next fetcher behavior."""
        self.cert_config_fetcher.reset()
        self.cert_config_fetcher.configure(should_raise, cert_config)

    def _assert_fetcher_call(
        self, call, did_raise, returned_cert_config, msg=None
    ):
        """Checks a single recorded `Call` against the expected outcome."""
        if did_raise:
            self.assertTrue(call.did_raise, msg)
        else:
            self.assertFalse(call.did_raise, msg)
        if returned_cert_config is None:
            self.assertIsNone(call.returned_cert_config, msg)
        else:
            self.assertEqual(
                call.returned_cert_config, returned_cert_config, msg
            )

    def _assert_single_fetcher_call(self, did_raise, returned_cert_config):
        """Asserts exactly one fetcher call was recorded, with this outcome."""
        actual_calls = self.cert_config_fetcher.getCalls()
        self.assertEqual(len(actual_calls), 1)
        self._assert_fetcher_call(
            actual_calls[0], did_raise, returned_cert_config
        )

    def _assert_fetcher_calls(self, returned_cert_config):
        """Asserts at least one non-raising call, each returning the config.

        Used after RPCs expected to fail, where the test does not pin the
        exact number of fetcher calls.
        """
        actual_calls = self.cert_config_fetcher.getCalls()
        self.assertGreaterEqual(len(actual_calls), 1)
        for i, call in enumerate(actual_calls):
            self._assert_fetcher_call(
                call, False, returned_cert_config, "i= {}".format(i)
            )

    def _open_persistent_channel(self):
        """Opens a channel that trusts CA 1 and presents client chain 2.

        The channel is closed by a cleanup registered immediately, so it is
        released even if a later assertion in the test fails.
        """
        channel = _create_channel(
            self.port,
            grpc.ssl_channel_credentials(
                root_certificates=CA_1_PEM,
                private_key=CLIENT_KEY_2_PEM,
                certificate_chain=CLIENT_CERT_CHAIN_2_PEM,
            ),
        )
        self.addCleanup(channel.close)
        return channel

    def _test(self):
        """Runs the full certificate-rotation scenario.

        Before rotation the server presents chain 1 and trusts CA 2. New
        clients must therefore trust CA 1 and, when client auth is required,
        present client chain 2. A raising fetcher must not disturb this.

        The fetcher then returns a config with server chain 2 that trusts
        CA 1. From then on new clients must trust CA 2 and present client
        chain 1. Two channels opened before the rotation must keep working
        afterwards without triggering further fetcher calls.
        """
        # things should work...
        # (The fetcher was created fresh in setUp, so no reset is needed.)
        self.cert_config_fetcher.configure(False, None)
        self._do_one_shot_client_rpc(
            True,
            root_certificates=CA_1_PEM,
            private_key=CLIENT_KEY_2_PEM,
            certificate_chain=CLIENT_CERT_CHAIN_2_PEM,
        )
        self._assert_single_fetcher_call(False, None)

        # client should reject server...
        # fails because client trusts ca2 and so will reject server
        self._reconfigure_fetcher(False, None)
        self._do_one_shot_client_rpc(
            False,
            root_certificates=CA_2_PEM,
            private_key=CLIENT_KEY_2_PEM,
            certificate_chain=CLIENT_CERT_CHAIN_2_PEM,
        )
        self._assert_fetcher_calls(None)

        # should work again, even though the fetcher raises...
        self._reconfigure_fetcher(True, None)
        self._do_one_shot_client_rpc(
            True,
            root_certificates=CA_1_PEM,
            private_key=CLIENT_KEY_2_PEM,
            certificate_chain=CLIENT_CERT_CHAIN_2_PEM,
        )
        self._assert_single_fetcher_call(True, None)

        # if with_client_auth, then client should be rejected by
        # server because client uses key/cert1, but server trusts ca2,
        # so server will reject
        self._reconfigure_fetcher(False, None)
        self._do_one_shot_client_rpc(
            not self.require_client_auth(),
            root_certificates=CA_1_PEM,
            private_key=CLIENT_KEY_1_PEM,
            certificate_chain=CLIENT_CERT_CHAIN_1_PEM,
        )
        self._assert_fetcher_calls(None)

        # should work again...
        self._reconfigure_fetcher(False, None)
        self._do_one_shot_client_rpc(
            True,
            root_certificates=CA_1_PEM,
            private_key=CLIENT_KEY_2_PEM,
            certificate_chain=CLIENT_CERT_CHAIN_2_PEM,
        )
        self._assert_single_fetcher_call(False, None)

        # now create the "persistent" clients
        self._reconfigure_fetcher(False, None)
        channel_A = self._open_persistent_channel()
        persistent_client_stub_A = _create_client_stub(channel_A, True)
        self._perform_rpc(persistent_client_stub_A, True)
        self._assert_single_fetcher_call(False, None)

        self._reconfigure_fetcher(False, None)
        channel_B = self._open_persistent_channel()
        persistent_client_stub_B = _create_client_stub(channel_B, True)
        self._perform_rpc(persistent_client_stub_B, True)
        self._assert_single_fetcher_call(False, None)

        # moment of truth!! client should reject server because the
        # server switch cert...
        cert_config = grpc.ssl_server_certificate_configuration(
            [(SERVER_KEY_2_PEM, SERVER_CERT_CHAIN_2_PEM)],
            root_certificates=CA_1_PEM,
        )
        self._reconfigure_fetcher(False, cert_config)
        self._do_one_shot_client_rpc(
            False,
            root_certificates=CA_1_PEM,
            private_key=CLIENT_KEY_2_PEM,
            certificate_chain=CLIENT_CERT_CHAIN_2_PEM,
        )
        self._assert_fetcher_calls(cert_config)

        # now should work again...
        self._reconfigure_fetcher(False, None)
        self._do_one_shot_client_rpc(
            True,
            root_certificates=CA_2_PEM,
            private_key=CLIENT_KEY_1_PEM,
            certificate_chain=CLIENT_CERT_CHAIN_1_PEM,
        )
        self._assert_single_fetcher_call(False, None)

        # client should be rejected by server if with_client_auth
        self._reconfigure_fetcher(False, None)
        self._do_one_shot_client_rpc(
            not self.require_client_auth(),
            root_certificates=CA_2_PEM,
            private_key=CLIENT_KEY_2_PEM,
            certificate_chain=CLIENT_CERT_CHAIN_2_PEM,
        )
        self._assert_fetcher_calls(None)

        # here client should reject server...
        self._reconfigure_fetcher(False, None)
        self._do_one_shot_client_rpc(
            False,
            root_certificates=CA_1_PEM,
            private_key=CLIENT_KEY_2_PEM,
            certificate_chain=CLIENT_CERT_CHAIN_2_PEM,
        )
        self._assert_fetcher_calls(None)

        # persistent clients should continue to work, without new handshakes
        # consulting the fetcher
        self._reconfigure_fetcher(False, None)
        self._perform_rpc(persistent_client_stub_A, True)
        self.assertEqual(len(self.cert_config_fetcher.getCalls()), 0)

        self._reconfigure_fetcher(False, None)
        self._perform_rpc(persistent_client_stub_B, True)
        self.assertEqual(len(self.cert_config_fetcher.getCalls()), 0)
        # channel_A and channel_B are closed by the cleanups registered in
        # _open_persistent_channel.
392
393
class ServerSSLCertConfigFetcherParamsChecks(unittest.TestCase):
    """Argument type checks of `grpc.dynamic_ssl_server_credentials`."""

    def test_check_on_initial_config(self):
        # None or an int as the initial configuration must be rejected.
        for invalid_initial_config in (None, 1):
            with self.assertRaises(TypeError):
                grpc.dynamic_ssl_server_credentials(invalid_initial_config, str)

    def test_check_on_config_fetcher(self):
        # With a valid initial configuration, None or an int as the fetcher
        # must still be rejected.
        valid_cert_config = grpc.ssl_server_certificate_configuration(
            [(SERVER_KEY_2_PEM, SERVER_CERT_CHAIN_2_PEM)],
            root_certificates=CA_1_PEM,
        )
        for invalid_fetcher in (None, 1):
            with self.assertRaises(TypeError):
                grpc.dynamic_ssl_server_credentials(
                    valid_cert_config, invalid_fetcher
                )
410
411
class ServerSSLCertReloadTestWithClientAuth(_ServerSSLCertReloadTest):
    """Rotation scenario against a server that requires client certs."""

    def require_client_auth(self):
        return True

    def test(self):
        # Delegate to the shared scenario defined on the base class.
        self._test()
417
418
class ServerSSLCertReloadTestWithoutClientAuth(_ServerSSLCertReloadTest):
    """Rotation scenario against a server that does not authenticate clients."""

    def require_client_auth(self):
        return False

    def test(self):
        # Delegate to the shared scenario defined on the base class.
        self._test()
424
425
class ServerSSLCertReloadTestCertConfigReuse(_ServerSSLCertReloadTest):
    """Checks that `ServerCertificateConfiguration` objects can be reused.

    gRPC Core takes ownership of the `grpc_ssl_server_certificate_config`
    wrapped by a `ServerCertificateConfiguration`. This test therefore hands
    the very same two configuration objects back from the fetcher again and
    again. That confirms gRPC Python keeps each instance valid, so user
    applications may return one instance more than once.
    """

    def require_client_auth(self):
        return True

    def setUp(self):
        self.server = test_common.test_server()
        servicer = _server_application.FirstServiceServicer()
        services_pb2_grpc.add_FirstServiceServicer_to_server(
            servicer, self.server
        )
        # Config A: server chain 1, trusting CA 2.
        self.cert_config_A = grpc.ssl_server_certificate_configuration(
            [(SERVER_KEY_1_PEM, SERVER_CERT_CHAIN_1_PEM)],
            root_certificates=CA_2_PEM,
        )
        # Config B: server chain 2, trusting CA 1.
        self.cert_config_B = grpc.ssl_server_certificate_configuration(
            [(SERVER_KEY_2_PEM, SERVER_CERT_CHAIN_2_PEM)],
            root_certificates=CA_1_PEM,
        )
        self.cert_config_fetcher = CertConfigFetcher()
        credentials = grpc.dynamic_ssl_server_credentials(
            self.cert_config_A,
            self.cert_config_fetcher,
            require_client_authentication=True,
        )
        self.port = self.server.add_secure_port("[::]:0", credentials)
        self.server.start()

    def test_cert_config_reuse(self):
        config_a = self.cert_config_A
        config_b = self.cert_config_B
        # Each phase: (config the fetcher returns, whether the RPC should
        # succeed, client trust root, client key, client cert chain).
        phases = (
            # succeed with A
            (config_a, True, CA_1_PEM, CLIENT_KEY_2_PEM, CLIENT_CERT_CHAIN_2_PEM),
            # fail with A
            (config_a, False, CA_2_PEM, CLIENT_KEY_1_PEM, CLIENT_CERT_CHAIN_1_PEM),
            # succeed again with A
            (config_a, True, CA_1_PEM, CLIENT_KEY_2_PEM, CLIENT_CERT_CHAIN_2_PEM),
            # succeed with B
            (config_b, True, CA_2_PEM, CLIENT_KEY_1_PEM, CLIENT_CERT_CHAIN_1_PEM),
            # fail with B
            (config_b, False, CA_1_PEM, CLIENT_KEY_2_PEM, CLIENT_CERT_CHAIN_2_PEM),
            # succeed again with B
            (config_b, True, CA_2_PEM, CLIENT_KEY_1_PEM, CLIENT_CERT_CHAIN_1_PEM),
        )
        fetcher = self.cert_config_fetcher
        for config, should_succeed, root, key, chain in phases:
            fetcher.reset()
            fetcher.configure(False, config)
            self._do_one_shot_client_rpc(
                should_succeed,
                root_certificates=root,
                private_key=key,
                certificate_chain=chain,
            )
            recorded = fetcher.getCalls()
            if should_succeed:
                self.assertEqual(len(recorded), 1)
                self.assertFalse(recorded[0].did_raise)
                self.assertEqual(recorded[0].returned_cert_config, config)
                continue
            self.assertGreaterEqual(len(recorded), 1)
            self.assertFalse(recorded[0].did_raise)
            for index, call in enumerate(recorded):
                self.assertFalse(call.did_raise, "i= {}".format(index))
                self.assertEqual(
                    call.returned_cert_config, config, "i= {}".format(index)
                )
563
564
if __name__ == "__main__":
    # Give the root logger a default stderr handler, then hand control to
    # the unittest runner with verbose per-test output.
    logging.basicConfig()
    unittest.main(verbosity=2)
568