1 //! A re-implementation of https://github.com/rustls/rustls/blob/v/0.20.8/rustls/tests/client_cert_verifier.rs
2
3 use std::io;
4 use std::sync::Arc;
5
6 use rustls::client::WebPkiVerifier;
7 use rustls::server::{ClientCertVerified, ClientCertVerifier};
8 use rustls::PrivateKey;
9 use rustls::{Certificate, DistinguishedNames, Error, ServerConfig, SignatureScheme};
10
/// Embeds PEM test fixtures located at
/// `$CARGO_MANIFEST_DIR/data/ca/<keytype>/<path>` into the binary at compile time.
///
/// Each `(NAME, "keytype", "path");` entry expands to a `const NAME: &[u8]` holding the
/// file's bytes. The macro also generates one `bytes_for(keytype, path)` function that
/// maps every `(keytype, path)` pair back to its constant.
///
/// `keytype` and `path` must be string literals. They are spliced into `concat!` and also
/// used as `match` patterns, and both positions only accept literals.
macro_rules! embed_files {
    (
        $(
            ($name:ident, $keytype:literal, $path:literal);
        )+
    ) => {
        $(
            // `'static` is implied for `const` references.
            const $name: &[u8] = include_bytes!(
                concat!(env!("CARGO_MANIFEST_DIR"), "/data/ca/", $keytype, "/", $path));
        )+

        /// Returns the embedded bytes of fixture `path` in directory `keytype`.
        ///
        /// # Panics
        ///
        /// Panics if no fixture with that `(keytype, path)` pair was embedded.
        pub fn bytes_for(keytype: &str, path: &str) -> &'static [u8] {
            match (keytype, path) {
                $(
                    ($keytype, $path) => $name,
                )+
                _ => panic!("unknown keytype {} with path {}", keytype, path),
            }
        }
    }
}
32
33 embed_files! {
34 (ECDSA_CA_CERT, "ecdsa", "ca.cert");
35 (ECDSA_CLIENT_FULLCHAIN, "ecdsa", "client.fullchain");
36 (ECDSA_CLIENT_KEY, "ecdsa", "client.key");
37 (ECDSA_END_FULLCHAIN, "ecdsa", "end.fullchain");
38 (ECDSA_END_KEY, "ecdsa", "end.key");
39
40 (EDDSA_CA_CERT, "eddsa", "ca.cert");
41 (EDDSA_CLIENT_FULLCHAIN, "eddsa", "client.fullchain");
42 (EDDSA_CLIENT_KEY, "eddsa", "client.key");
43 (EDDSA_END_FULLCHAIN, "eddsa", "end.fullchain");
44 (EDDSA_END_KEY, "eddsa", "end.key");
45
46 (RSA_CA_CERT, "rsa", "ca.cert");
47 (RSA_CLIENT_FULLCHAIN, "rsa", "client.fullchain");
48 (RSA_CLIENT_KEY, "rsa", "client.key");
49 (RSA_END_FULLCHAIN, "rsa", "end.fullchain");
50 (RSA_END_KEY, "rsa", "end.key");
51 }
52
/// The signature algorithm family of a set of embedded test credentials.
///
/// `Debug` and `Eq` are derived so the type can be used in `assert_eq!` and printed in
/// failure messages.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum KeyType {
    /// RSA credentials (fixture directory "rsa").
    Rsa,
    /// ECDSA credentials (fixture directory "ecdsa").
    Ecdsa,
    /// Ed25519 credentials (fixture directory "eddsa").
    Ed25519,
}
59
/// Every [`KeyType`] variant, in declaration order; the tests loop over this to cover
/// each algorithm.
pub static ALL_KEY_TYPES: [KeyType; 3] = [KeyType::Rsa, KeyType::Ecdsa, KeyType::Ed25519];
61
62 impl KeyType {
bytes_for(&self, part: &str) -> &'static [u8]63 fn bytes_for(&self, part: &str) -> &'static [u8] {
64 match self {
65 Self::Rsa => bytes_for("rsa", part),
66 Self::Ecdsa => bytes_for("ecdsa", part),
67 Self::Ed25519 => bytes_for("eddsa", part),
68 }
69 }
70
get_chain(&self) -> Vec<Certificate>71 pub fn get_chain(&self) -> Vec<Certificate> {
72 rustls_pemfile::certs(&mut io::BufReader::new(self.bytes_for("end.fullchain")))
73 .unwrap()
74 .iter()
75 .map(|v| Certificate(v.clone()))
76 .collect()
77 }
78
get_key(&self) -> PrivateKey79 pub fn get_key(&self) -> PrivateKey {
80 PrivateKey(
81 rustls_pemfile::pkcs8_private_keys(&mut io::BufReader::new(self.bytes_for("end.key")))
82 .unwrap()[0]
83 .clone(),
84 )
85 }
86
get_client_chain(&self) -> Vec<Certificate>87 pub fn get_client_chain(&self) -> Vec<Certificate> {
88 rustls_pemfile::certs(&mut io::BufReader::new(self.bytes_for("client.fullchain")))
89 .unwrap()
90 .iter()
91 .map(|v| Certificate(v.clone()))
92 .collect()
93 }
94
get_client_key(&self) -> PrivateKey95 pub fn get_client_key(&self) -> PrivateKey {
96 PrivateKey(
97 rustls_pemfile::pkcs8_private_keys(&mut io::BufReader::new(
98 self.bytes_for("client.key"),
99 ))
100 .unwrap()[0]
101 .clone(),
102 )
103 }
104 }
105
/// A handshake failure, tagged with the side of the connection that reported it.
#[derive(Debug)]
pub enum ErrorFromPeer {
    /// Returned by the client while processing incoming packets.
    Client(Error),
    /// Returned by the server while processing incoming packets.
    Server(Error),
}
111
/// A [`ClientCertVerifier`] whose every answer is scripted through its fields.
pub struct MockClientVerifier {
    /// Invoked by `verify_client_cert`; its result is returned as-is and the presented
    /// certificates are ignored.
    pub verified: fn() -> Result<ClientCertVerified, Error>,
    /// Value reported by `client_auth_root_subjects`.
    pub subjects: Option<DistinguishedNames>,
    /// Value reported by `client_auth_mandatory`.
    pub mandatory: Option<bool>,
    /// Value reported by `supported_verify_schemes`. `None` falls back to
    /// `WebPkiVerifier::verification_schemes()`.
    pub offered_schemes: Option<Vec<SignatureScheme>>,
}
118
119 impl ClientCertVerifier for MockClientVerifier {
client_auth_mandatory(&self) -> Option<bool>120 fn client_auth_mandatory(&self) -> Option<bool> {
121 self.mandatory
122 }
123
client_auth_root_subjects(&self) -> Option<DistinguishedNames>124 fn client_auth_root_subjects(&self) -> Option<DistinguishedNames> {
125 self.subjects.as_ref().cloned()
126 }
127
verify_client_cert( &self, _end_entity: &Certificate, _intermediates: &[Certificate], _now: std::time::SystemTime, ) -> Result<ClientCertVerified, Error>128 fn verify_client_cert(
129 &self,
130 _end_entity: &Certificate,
131 _intermediates: &[Certificate],
132 _now: std::time::SystemTime,
133 ) -> Result<ClientCertVerified, Error> {
134 (self.verified)()
135 }
136
supported_verify_schemes(&self) -> Vec<SignatureScheme>137 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
138 if let Some(schemes) = &self.offered_schemes {
139 schemes.clone()
140 } else {
141 WebPkiVerifier::verification_schemes()
142 }
143 }
144 }
145
server_config_with_verifier( kt: KeyType, client_cert_verifier: MockClientVerifier, ) -> ServerConfig146 pub fn server_config_with_verifier(
147 kt: KeyType,
148 client_cert_verifier: MockClientVerifier,
149 ) -> ServerConfig {
150 ServerConfig::builder()
151 .with_safe_defaults()
152 .with_client_cert_verifier(Arc::new(client_cert_verifier))
153 .with_single_cert(kt.get_chain(), kt.get_key())
154 .unwrap()
155 }
156
#[cfg(test)]
mod test {
    use super::*;

    use std::convert::TryInto;
    use std::io;
    use std::ops::DerefMut;
    use std::sync::Arc;

    use rustls::internal::msgs::base::PayloadU16;
    use rustls::server::ClientCertVerified;
    use rustls::{
        ClientConfig, ClientConnection, ConnectionCommon, Error, RootCertStore, ServerConfig,
        ServerConnection, SideData,
    };

    /// Asserts that two values have identical `Debug` renderings.
    ///
    /// Compares types that implement `Debug` but not `PartialEq`, such as
    /// `Result<(), ErrorFromPeer>`.
    fn assert_debug_eq<T>(err: T, expect: T)
    where
        T: std::fmt::Debug,
    {
        assert_eq!(format!("{err:?}"), format!("{expect:?}"));
    }

    /// Completes a client config builder. The result trusts the `ca.cert` fixture of `kt`
    /// and presents `kt`'s client chain and key for client authentication.
    fn finish_client_config_with_creds(
        kt: KeyType,
        config: rustls::ConfigBuilder<ClientConfig, rustls::WantsVerifier>,
    ) -> ClientConfig {
        let mut root_store = RootCertStore::empty();
        let mut rootbuf = io::BufReader::new(kt.bytes_for("ca.cert"));
        root_store.add_parsable_certificates(&rustls_pemfile::certs(&mut rootbuf).unwrap());

        config
            .with_root_certificates(root_store)
            .with_single_cert(kt.get_client_chain(), kt.get_client_key())
            .unwrap()
    }

    /// Builds an authenticating client config for `kt`, restricted to `versions`.
    fn make_client_config_with_versions_with_auth(
        kt: KeyType,
        versions: &[&'static rustls::SupportedProtocolVersion],
    ) -> ClientConfig {
        let builder = ClientConfig::builder()
            .with_safe_default_cipher_suites()
            .with_safe_default_kx_groups()
            .with_protocol_versions(versions)
            .unwrap();
        finish_client_config_with_creds(kt, builder)
    }

    /// Builds a root store from the *end-entity* chain of `kt`.
    ///
    /// The tests read the subject names from this store to script the mock verifier.
    fn get_client_root_store(kt: KeyType) -> RootCertStore {
        let roots = kt.get_chain();
        let mut client_auth_roots = RootCertStore::empty();
        for root in roots {
            client_auth_roots.add(&root).unwrap();
        }
        client_auth_roots
    }

    /// Shuttles records between `client` and `server` until neither side is handshaking.
    ///
    /// Returns the first processing error, tagged with the side that produced it.
    fn do_handshake_until_error(
        client: &mut ClientConnection,
        server: &mut ServerConnection,
    ) -> Result<(), ErrorFromPeer> {
        while server.is_handshaking() || client.is_handshaking() {
            transfer(client, server);
            server
                .process_new_packets()
                .map_err(ErrorFromPeer::Server)?;
            transfer(server, client);
            client
                .process_new_packets()
                .map_err(ErrorFromPeer::Client)?;
        }

        Ok(())
    }

    /// Moves all pending TLS output from `left` into `right`'s input.
    ///
    /// Returns the number of bytes transferred.
    fn transfer(
        left: &mut impl DerefMut<Target = ConnectionCommon<impl SideData>>,
        right: &mut impl DerefMut<Target = ConnectionCommon<impl SideData>>,
    ) -> usize {
        let mut buf = [0u8; 262144];
        let mut total = 0;

        while left.wants_write() {
            let sz = {
                let into_buf: &mut dyn io::Write = &mut &mut buf[..];
                left.write_tls(into_buf).unwrap()
            };
            total += sz;
            if sz == 0 {
                return total;
            }

            // `read_tls` may consume only part of the buffer, so feed it until all is taken.
            let mut offs = 0;
            loop {
                let from_buf: &mut dyn io::Read = &mut &buf[offs..sz];
                let read = right.read_tls(from_buf).unwrap();
                // A zero-length read would never advance `offs`; fail instead of spinning.
                assert!(read > 0, "read_tls made no progress");
                offs += read;
                if sz == offs {
                    break;
                }
            }
        }

        total
    }

    /// Converts a static host name into a `ServerName`, panicking if it is invalid.
    fn dns_name(name: &'static str) -> rustls::ServerName {
        name.try_into().unwrap()
    }

    /// Creates a client connection to "localhost" and a server connection from shared configs.
    fn make_pair_for_arc_configs(
        client_config: &Arc<ClientConfig>,
        server_config: &Arc<ServerConfig>,
    ) -> (ClientConnection, ServerConnection) {
        (
            ClientConnection::new(Arc::clone(client_config), dns_name("localhost")).unwrap(),
            ServerConnection::new(Arc::clone(server_config)).unwrap(),
        )
    }

    /// Verification callback that always accepts the client certificate.
    fn ver_ok() -> Result<ClientCertVerified, Error> {
        Ok(rustls::server::ClientCertVerified::assertion())
    }

    /// Happy path: the mock verifier advertises the root subjects and accepts the client
    /// certificate. The handshake must therefore complete for every key type and every
    /// protocol version.
    #[test]
    fn client_verifier_works() {
        // `KeyType` is `Copy`, so bind by value instead of dereferencing at each use.
        for &kt in ALL_KEY_TYPES.iter() {
            let client_verifier = MockClientVerifier {
                verified: ver_ok,
                subjects: Some(
                    get_client_root_store(kt)
                        .roots
                        .iter()
                        .map(|r| PayloadU16(r.subject().to_vec()))
                        .collect(),
                ),
                mandatory: Some(true),
                offered_schemes: None,
            };

            let server_config = Arc::new(server_config_with_verifier(kt, client_verifier));

            for version in rustls::ALL_VERSIONS {
                // The config is used once, so move it into the Arc rather than cloning it.
                let client_config =
                    Arc::new(make_client_config_with_versions_with_auth(kt, &[version]));
                let (mut client, mut server) =
                    make_pair_for_arc_configs(&client_config, &server_config);
                let err = do_handshake_until_error(&mut client, &mut server);
                assert_debug_eq(err, Ok(()));
            }
        }
    }
}
311