1 //! A re-implementation of https://github.com/rustls/rustls/blob/v/0.20.8/rustls/tests/client_cert_verifier.rs
2 
3 use std::io;
4 use std::sync::Arc;
5 
6 use rustls::client::WebPkiVerifier;
7 use rustls::server::{ClientCertVerified, ClientCertVerifier};
8 use rustls::PrivateKey;
9 use rustls::{Certificate, DistinguishedNames, Error, ServerConfig, SignatureScheme};
10 
/// Embeds PEM fixtures from `$CARGO_MANIFEST_DIR/data/ca/<keytype>/<path>` at compile
/// time and generates a `bytes_for(keytype, path)` lookup over them.
///
/// Each entry `(NAME, "keytype", "path");` produces a `const NAME: &[u8]` holding the
/// file contents plus one match arm in `bytes_for`. `keytype` and `path` must be string
/// literals: they are spliced into `concat!` and also used as match patterns.
macro_rules! embed_files {
    (
        $(
            ($name:ident, $keytype:literal, $path:literal);
        )+
    ) => {
        $(
            const $name: &[u8] = include_bytes!(
                concat!(env!("CARGO_MANIFEST_DIR"), "/data/ca/", $keytype, "/", $path));
        )+

        /// Returns the embedded bytes of fixture `path` for key-type directory `keytype`.
        ///
        /// # Panics
        /// Panics if no fixture was embedded for the `(keytype, path)` pair.
        pub fn bytes_for(keytype: &str, path: &str) -> &'static [u8] {
            match (keytype, path) {
                $(
                    ($keytype, $path) => $name,
                )+
                _ => panic!("unknown keytype {} with path {}", keytype, path),
            }
        }
    }
}
32 
33 embed_files! {
34     (ECDSA_CA_CERT, "ecdsa", "ca.cert");
35     (ECDSA_CLIENT_FULLCHAIN, "ecdsa", "client.fullchain");
36     (ECDSA_CLIENT_KEY, "ecdsa", "client.key");
37     (ECDSA_END_FULLCHAIN, "ecdsa", "end.fullchain");
38     (ECDSA_END_KEY, "ecdsa", "end.key");
39 
40     (EDDSA_CA_CERT, "eddsa", "ca.cert");
41     (EDDSA_CLIENT_FULLCHAIN, "eddsa", "client.fullchain");
42     (EDDSA_CLIENT_KEY, "eddsa", "client.key");
43     (EDDSA_END_FULLCHAIN, "eddsa", "end.fullchain");
44     (EDDSA_END_KEY, "eddsa", "end.key");
45 
46     (RSA_CA_CERT, "rsa", "ca.cert");
47     (RSA_CLIENT_FULLCHAIN, "rsa", "client.fullchain");
48     (RSA_CLIENT_KEY, "rsa", "client.key");
49     (RSA_END_FULLCHAIN, "rsa", "end.fullchain");
50     (RSA_END_KEY, "rsa", "end.key");
51 }
52 
/// Key-algorithm families for which PEM fixtures are embedded.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum KeyType {
    /// Fixtures under `data/ca/rsa`.
    Rsa,
    /// Fixtures under `data/ca/ecdsa`.
    Ecdsa,
    /// Fixtures under `data/ca/eddsa`.
    Ed25519,
}

/// Every [`KeyType`], in the order the tests iterate over them.
pub static ALL_KEY_TYPES: [KeyType; 3] = [KeyType::Rsa, KeyType::Ecdsa, KeyType::Ed25519];
61 
62 impl KeyType {
bytes_for(&self, part: &str) -> &'static [u8]63     fn bytes_for(&self, part: &str) -> &'static [u8] {
64         match self {
65             Self::Rsa => bytes_for("rsa", part),
66             Self::Ecdsa => bytes_for("ecdsa", part),
67             Self::Ed25519 => bytes_for("eddsa", part),
68         }
69     }
70 
get_chain(&self) -> Vec<Certificate>71     pub fn get_chain(&self) -> Vec<Certificate> {
72         rustls_pemfile::certs(&mut io::BufReader::new(self.bytes_for("end.fullchain")))
73             .unwrap()
74             .iter()
75             .map(|v| Certificate(v.clone()))
76             .collect()
77     }
78 
get_key(&self) -> PrivateKey79     pub fn get_key(&self) -> PrivateKey {
80         PrivateKey(
81             rustls_pemfile::pkcs8_private_keys(&mut io::BufReader::new(self.bytes_for("end.key")))
82                 .unwrap()[0]
83                 .clone(),
84         )
85     }
86 
get_client_chain(&self) -> Vec<Certificate>87     pub fn get_client_chain(&self) -> Vec<Certificate> {
88         rustls_pemfile::certs(&mut io::BufReader::new(self.bytes_for("client.fullchain")))
89             .unwrap()
90             .iter()
91             .map(|v| Certificate(v.clone()))
92             .collect()
93     }
94 
get_client_key(&self) -> PrivateKey95     pub fn get_client_key(&self) -> PrivateKey {
96         PrivateKey(
97             rustls_pemfile::pkcs8_private_keys(&mut io::BufReader::new(
98                 self.bytes_for("client.key"),
99             ))
100             .unwrap()[0]
101                 .clone(),
102         )
103     }
104 }
105 
106 #[derive(Debug)]
107 pub enum ErrorFromPeer {
108     Client(Error),
109     Server(Error),
110 }
111 
112 pub struct MockClientVerifier {
113     pub verified: fn() -> Result<ClientCertVerified, Error>,
114     pub subjects: Option<DistinguishedNames>,
115     pub mandatory: Option<bool>,
116     pub offered_schemes: Option<Vec<SignatureScheme>>,
117 }
118 
119 impl ClientCertVerifier for MockClientVerifier {
client_auth_mandatory(&self) -> Option<bool>120     fn client_auth_mandatory(&self) -> Option<bool> {
121         self.mandatory
122     }
123 
client_auth_root_subjects(&self) -> Option<DistinguishedNames>124     fn client_auth_root_subjects(&self) -> Option<DistinguishedNames> {
125         self.subjects.as_ref().cloned()
126     }
127 
verify_client_cert( &self, _end_entity: &Certificate, _intermediates: &[Certificate], _now: std::time::SystemTime, ) -> Result<ClientCertVerified, Error>128     fn verify_client_cert(
129         &self,
130         _end_entity: &Certificate,
131         _intermediates: &[Certificate],
132         _now: std::time::SystemTime,
133     ) -> Result<ClientCertVerified, Error> {
134         (self.verified)()
135     }
136 
supported_verify_schemes(&self) -> Vec<SignatureScheme>137     fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
138         if let Some(schemes) = &self.offered_schemes {
139             schemes.clone()
140         } else {
141             WebPkiVerifier::verification_schemes()
142         }
143     }
144 }
145 
server_config_with_verifier( kt: KeyType, client_cert_verifier: MockClientVerifier, ) -> ServerConfig146 pub fn server_config_with_verifier(
147     kt: KeyType,
148     client_cert_verifier: MockClientVerifier,
149 ) -> ServerConfig {
150     ServerConfig::builder()
151         .with_safe_defaults()
152         .with_client_cert_verifier(Arc::new(client_cert_verifier))
153         .with_single_cert(kt.get_chain(), kt.get_key())
154         .unwrap()
155 }
156 
#[cfg(test)]
mod test {
    use super::*;

    use std::convert::TryInto;
    use std::io;
    use std::ops::DerefMut;
    use std::sync::Arc;

    use rustls::internal::msgs::base::PayloadU16;
    use rustls::server::ClientCertVerified;
    use rustls::{
        ClientConfig, ClientConnection, ConnectionCommon, Error, RootCertStore, ServerConfig,
        ServerConnection, SideData,
    };

    /// Asserts that two values render identically under `Debug`.
    ///
    /// Lets the tests compare values such as `Result<(), ErrorFromPeer>`, whose
    /// error type derives `Debug` but not `PartialEq`.
    fn assert_debug_eq<T>(err: T, expect: T)
    where
        T: std::fmt::Debug,
    {
        assert_eq!(format!("{err:?}"), format!("{expect:?}"));
    }

    /// Finishes `config` by trusting `kt`'s CA certificate and presenting `kt`'s
    /// client chain and key for client authentication.
    ///
    /// # Panics
    /// Panics if the CA fixture yields no usable certificate or rustls rejects
    /// the client credentials.
    fn finish_client_config_with_creds(
        kt: KeyType,
        config: rustls::ConfigBuilder<ClientConfig, rustls::WantsVerifier>,
    ) -> ClientConfig {
        let mut rootbuf = io::BufReader::new(kt.bytes_for("ca.cert"));
        let ca_certs = rustls_pemfile::certs(&mut rootbuf).unwrap();

        let mut root_store = RootCertStore::empty();
        // `add_parsable_certificates` skips anything it cannot parse; an empty
        // store would otherwise only show up later as an opaque handshake error.
        let (valid, _ignored) = root_store.add_parsable_certificates(&ca_certs);
        assert!(valid > 0, "no usable CA certificate in the ca.cert fixture");

        config
            .with_root_certificates(root_store)
            .with_single_cert(kt.get_client_chain(), kt.get_client_key())
            .unwrap()
    }

    /// Builds a client config restricted to `versions`, using the safe default
    /// cipher suites and key-exchange groups plus `kt`'s trust root and client
    /// credentials.
    fn make_client_config_with_versions_with_auth(
        kt: KeyType,
        versions: &[&'static rustls::SupportedProtocolVersion],
    ) -> ClientConfig {
        let builder = ClientConfig::builder()
            .with_safe_default_cipher_suites()
            .with_safe_default_kx_groups()
            .with_protocol_versions(versions)
            .unwrap();
        finish_client_config_with_creds(kt, builder)
    }

    /// Collects `kt`'s end-entity chain into a root store.
    ///
    /// The tests only read the subjects of these certificates, to populate the
    /// mock verifier's advertised root subjects.
    fn get_client_root_store(kt: KeyType) -> RootCertStore {
        let mut client_auth_roots = RootCertStore::empty();
        for root in kt.get_chain() {
            client_auth_roots.add(&root).unwrap();
        }
        client_auth_roots
    }

    /// Shuttles records between the peers until neither is handshaking.
    ///
    /// Returns the first error either side reports, tagged with that side.
    fn do_handshake_until_error(
        client: &mut ClientConnection,
        server: &mut ServerConnection,
    ) -> Result<(), ErrorFromPeer> {
        while server.is_handshaking() || client.is_handshaking() {
            transfer(client, server);
            server
                .process_new_packets()
                .map_err(ErrorFromPeer::Server)?;
            transfer(server, client);
            client
                .process_new_packets()
                .map_err(ErrorFromPeer::Client)?;
        }

        Ok(())
    }

    /// Moves all pending TLS data from `left` into `right`, returning the number
    /// of bytes moved. Stops early if `left` produces no bytes.
    fn transfer(
        left: &mut impl DerefMut<Target = ConnectionCommon<impl SideData>>,
        right: &mut impl DerefMut<Target = ConnectionCommon<impl SideData>>,
    ) -> usize {
        let mut buf = [0u8; 262144];
        let mut total = 0;

        while left.wants_write() {
            let sz = {
                let into_buf: &mut dyn io::Write = &mut &mut buf[..];
                left.write_tls(into_buf).unwrap()
            };
            total += sz;
            if sz == 0 {
                return total;
            }

            // `read_tls` may consume only part of the chunk; keep feeding the rest.
            let mut offs = 0;
            loop {
                let from_buf: &mut dyn io::Read = &mut &buf[offs..sz];
                offs += right.read_tls(from_buf).unwrap();
                if sz == offs {
                    break;
                }
            }
        }

        total
    }

    /// Converts a static hostname into a `ServerName`, panicking if it is invalid.
    fn dns_name(name: &'static str) -> rustls::ServerName {
        name.try_into().unwrap()
    }

    /// Creates a client connection to `localhost` and a server connection from
    /// the shared configs.
    fn make_pair_for_arc_configs(
        client_config: &Arc<ClientConfig>,
        server_config: &Arc<ServerConfig>,
    ) -> (ClientConnection, ServerConnection) {
        (
            ClientConnection::new(Arc::clone(client_config), dns_name("localhost")).unwrap(),
            ServerConnection::new(Arc::clone(server_config)).unwrap(),
        )
    }

    /// Scripted verifier outcome that accepts any client certificate.
    fn ver_ok() -> Result<ClientCertVerified, Error> {
        Ok(ClientCertVerified::assertion())
    }

    /// Happy path: the mock advertises the root subjects, mandates client auth
    /// and accepts the client certificate, so the handshake must succeed for
    /// every key type and protocol version.
    #[test]
    fn client_verifier_works() {
        for &kt in ALL_KEY_TYPES.iter() {
            let client_verifier = MockClientVerifier {
                verified: ver_ok,
                subjects: Some(
                    get_client_root_store(kt)
                        .roots
                        .iter()
                        .map(|r| PayloadU16(r.subject().to_vec()))
                        .collect(),
                ),
                mandatory: Some(true),
                offered_schemes: None,
            };
            let server_config = Arc::new(server_config_with_verifier(kt, client_verifier));

            for version in rustls::ALL_VERSIONS {
                // The config is used only here, so wrap it directly instead of cloning.
                let client_config =
                    Arc::new(make_client_config_with_versions_with_auth(kt, &[version]));
                let (mut client, mut server) =
                    make_pair_for_arc_configs(&client_config, &server_config);
                let err = do_handshake_until_error(&mut client, &mut server);
                assert_debug_eq(err, Ok(()));
            }
        }
    }
}
311