1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5 use alloc::vec::Vec;
6 use itertools::Itertools;
7
8 use crate::crypto::HpkeContextR;
9
10 use super::{
11 CipherSuiteProvider, CryptoProvider, HpkeCiphertext, HpkeContextS, HpkePublicKey, HpkeSecretKey,
12 };
13
14 const PATH: &str = concat!(
15 env!("CARGO_MANIFEST_DIR"),
16 "/test_data/crypto_provider.json"
17 );
18
19 #[cfg(any(target_arch = "wasm32", not(feature = "std")))]
20 const SERIALIZED_TEST_SUITES: &[u8] = include_bytes!(concat!(
21 env!("CARGO_MANIFEST_DIR"),
22 "/test_data/crypto_provider.json"
23 ));
24
25 pub use hpke_rfc_conformance::{
26 verify_hpke_context_tests, verify_hpke_encap_tests, EncapOutput, TestHpke,
27 };
28
/// Byte lengths used for the random inputs (messages, info strings, AADs, key
/// material, output lengths) when generating randomized test vectors.
///
/// Generators that need a non-empty input skip the first entry.
pub const DATA_SIZES: [usize; 5] = [
    0,
    1,
    16,
    123,
    2000,
];
30
/// All stored test vectors for a single cipher suite, as stored in
/// `test_data/crypto_provider.json` (one JSON entry per cipher suite).
///
/// Every vector list is `#[serde(default)]`, so an entry may omit any category
/// it has no vectors for.
#[derive(serde::Serialize, serde::Deserialize, Default)]
struct TestSuite {
    /// Numeric cipher suite identifier, converted with `.into()` when the
    /// suite is generated or verified.
    cipher_suite: u16,
    /// Signature vectors; regenerated by `generate_tests`.
    #[serde(default)]
    signature_tests: Vec<SignatureTestCase>,
    /// AEAD vectors; `generate_tests` writes these back exactly as loaded.
    #[serde(default)]
    aead_tests: Vec<AeadTestCase>,
    /// HPKE vectors; regenerated by `generate_tests`.
    #[serde(default)]
    hpke_tests: HpkeTestCases,
    /// HKDF vectors; regenerated by `generate_tests`.
    #[serde(default)]
    hkdf_tests: Vec<HkdfTestCase>,
    /// MAC vectors; `generate_tests` writes these back exactly as loaded.
    #[serde(default)]
    mac_tests: Vec<MacTestCase>,
    /// Hash vectors; `generate_tests` writes these back exactly as loaded.
    #[serde(default)]
    hash_tests: Vec<HashTestCase>,
}
47
/// Regenerates the randomized vectors in `test_data/crypto_provider.json`.
///
/// Existing suites are loaded from disk. If the file is missing, one empty suite
/// is created per supported cipher suite. The signature, HPKE and HKDF vectors
/// of every suite are then replaced with freshly generated ones. The AEAD, MAC
/// and hash vectors are written back untouched.
///
/// # Panics
///
/// Panics in any of these cases:
/// - `crypto` cannot build a provider for one of its advertised cipher suites;
/// - `crypto` cannot build a provider for a suite listed in the existing file;
/// - reading, parsing or writing the file fails.
#[cfg(all(not(mls_build_async), not(target_arch = "wasm32"), feature = "std"))]
#[cfg_attr(coverage_nightly, coverage(off))]
pub fn generate_tests<C: CryptoProvider>(crypto: &C) {
    // Every advertised cipher suite must yield a provider.
    crypto
        .supported_cipher_suites()
        .into_iter()
        .for_each(|suite_id| {
            crypto.cipher_suite_provider(suite_id).unwrap();
        });

    let mut suites = create_or_load_tests(crypto);

    for suite in &mut suites {
        let provider = crypto
            .cipher_suite_provider(suite.cipher_suite.into())
            .unwrap();

        suite.signature_tests = generate_signature_tests(&provider);
        suite.hpke_tests = generate_hpke_tests(&provider);
        suite.hkdf_tests = generate_hkdf_tests(&provider);
    }

    let json = serde_json::to_string_pretty(&suites).unwrap();
    std::fs::write(PATH, json).unwrap();
}
68
/// Loads the stored test suites from `PATH`.
///
/// If that file does not exist yet, returns one empty suite for every cipher
/// suite supported by `crypto`.
///
/// # Panics
///
/// Panics if the file exists but cannot be read or parsed.
#[cfg(all(not(mls_build_async), not(target_arch = "wasm32"), feature = "std"))]
#[cfg_attr(coverage_nightly, coverage(off))]
fn create_or_load_tests<C: CryptoProvider>(crypto: &C) -> Vec<TestSuite> {
    if !std::path::Path::new(PATH).exists() {
        return crypto
            .supported_cipher_suites()
            .into_iter()
            .map(|suite_id| TestSuite {
                cipher_suite: suite_id.into(),
                ..Default::default()
            })
            .collect();
    }

    let stored = std::fs::read(PATH).unwrap();
    serde_json::from_slice(&stored).unwrap()
}
85
86 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
verify_tests<C: CryptoProvider>(crypto: &C, signature_secret_key_compatible: bool)87 pub async fn verify_tests<C: CryptoProvider>(crypto: &C, signature_secret_key_compatible: bool) {
88 #[cfg(any(target_arch = "wasm32", not(feature = "std")))]
89 let test_suites: Vec<TestSuite> = serde_json::from_slice(SERIALIZED_TEST_SUITES).unwrap();
90
91 #[cfg(all(not(target_arch = "wasm32"), feature = "std"))]
92 let test_suites: Vec<TestSuite> =
93 serde_json::from_slice(&std::fs::read(PATH).unwrap()).unwrap();
94
95 for test_suite in test_suites {
96 let test_cs = test_suite.cipher_suite.into();
97
98 let Some(cs) = crypto.cipher_suite_provider(test_cs) else {
99 continue;
100 };
101
102 assert_eq!(cs.cipher_suite(), test_cs);
103
104 verify_hkdf_tests(&cs, test_suite.hkdf_tests).await;
105 verify_aead_tests(&cs, test_suite.aead_tests).await;
106 verify_mac_tests(&cs, test_suite.mac_tests).await;
107 verify_hpke_tests(&cs, test_suite.hpke_tests).await;
108
109 verify_signature_tests(
110 &cs,
111 test_suite.signature_tests,
112 signature_secret_key_compatible,
113 )
114 .await;
115
116 verify_hash_tests(&cs, test_suite.hash_tests).await;
117 }
118 }
119
/// One signature vector. All fields are hex-encoded in the JSON file.
#[derive(serde::Serialize, serde::Deserialize)]
struct SignatureTestCase {
    /// Serialized signature secret key (`to_vec()` of the generated key).
    #[serde(with = "hex::serde")]
    secret: Vec<u8>,
    /// Serialized signature public key matching `secret`.
    #[serde(with = "hex::serde")]
    public: Vec<u8>,
    /// Signed message.
    #[serde(with = "hex::serde")]
    data: Vec<u8>,
    /// Signature over `data` made with `secret`.
    #[serde(with = "hex::serde")]
    signature: Vec<u8>,
}
131
132 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
verify_signature_tests<C: CipherSuiteProvider>( cs: &C, test_cases: Vec<SignatureTestCase>, secret_key_compatible: bool, )133 async fn verify_signature_tests<C: CipherSuiteProvider>(
134 cs: &C,
135 test_cases: Vec<SignatureTestCase>,
136 secret_key_compatible: bool,
137 ) {
138 // Checks that `cs` can sign and verify
139 let generated = generate_signature_tests(cs).await;
140
141 for (test_case, is_generated) in test_cases
142 .into_iter()
143 .map(|tc| (tc, false))
144 .chain(generated.into_iter().map(|tc| (tc, true)))
145 {
146 let public = test_case.public.into();
147
148 // Checks that `cs` can verify signatures generated by itself and another implementation
149 cs.verify(&public, &test_case.signature, &test_case.data)
150 .await
151 .unwrap();
152
153 if is_generated || secret_key_compatible {
154 let secret = test_case.secret.into();
155
156 let derived = cs.signature_key_derive_public(&secret).await.unwrap();
157
158 cs.sign(&secret, b"hello world").await.unwrap();
159
160 assert_eq!(derived, public);
161 }
162 }
163 }
164
165 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
166 #[cfg_attr(coverage_nightly, coverage(off))]
generate_signature_tests<C: CipherSuiteProvider>(cs: &C) -> Vec<SignatureTestCase>167 async fn generate_signature_tests<C: CipherSuiteProvider>(cs: &C) -> Vec<SignatureTestCase> {
168 let mut tests = Vec::new();
169
170 for data_size in DATA_SIZES {
171 let data = cs.random_bytes_vec(data_size).unwrap();
172 let (secret, public) = cs.signature_key_generate().await.unwrap();
173 let signature = cs.sign(&secret, &data).await.unwrap();
174
175 tests.push(SignatureTestCase {
176 secret: secret.to_vec(),
177 public: public.to_vec(),
178 data,
179 signature,
180 });
181 }
182
183 tests
184 }
185
// AEAD known-answer vectors taken from an RFC.
// NOTE(review): the specific RFC is not recorded here; confirm the source.
// `generate_tests` never regenerates these; they are written back as loaded.
#[derive(serde::Deserialize, serde::Serialize)]
struct AeadTestCase {
    /// AEAD key (hex in JSON).
    #[serde(with = "hex::serde")]
    pub key: Vec<u8>,
    /// Nonce passed to `aead_seal` / `aead_open`.
    #[serde(with = "hex::serde")]
    pub iv: Vec<u8>,
    /// Expected output of `aead_seal` over `pt`.
    #[serde(with = "hex::serde")]
    pub ct: Vec<u8>,
    /// Associated data; passed as `Some` even when empty.
    #[serde(with = "hex::serde")]
    pub aad: Vec<u8>,
    /// Plaintext.
    #[serde(with = "hex::serde")]
    pub pt: Vec<u8>,
}
200
201 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
verify_aead_tests<C: CipherSuiteProvider>(cs: &C, test_cases: Vec<AeadTestCase>)202 async fn verify_aead_tests<C: CipherSuiteProvider>(cs: &C, test_cases: Vec<AeadTestCase>) {
203 for case in test_cases {
204 let ciphertext = cs
205 .aead_seal(&case.key, &case.pt, Some(&case.aad), &case.iv)
206 .await
207 .unwrap();
208
209 assert_eq!(ciphertext, case.ct);
210
211 let plaintext = cs
212 .aead_open(&case.key, &ciphertext, Some(&case.aad), &case.iv)
213 .await
214 .unwrap();
215
216 assert_eq!(plaintext.to_vec(), case.pt);
217 }
218 }
219
/// HPKE vectors for one cipher suite.
///
/// Holds a receiver key pair derived from `ikm`, plus seal and export cases
/// made against that key pair's public key.
#[derive(serde::Serialize, serde::Deserialize, Default)]
struct HpkeTestCases {
    /// Input keying material passed to `kem_derive`.
    #[serde(with = "hex::serde")]
    ikm: Vec<u8>,
    /// Serialized secret key that `kem_derive(ikm)` must reproduce.
    #[serde(with = "hex::serde")]
    secret: Vec<u8>,
    /// Serialized public key that `kem_derive(ikm)` must reproduce.
    #[serde(with = "hex::serde")]
    public: Vec<u8>,

    // Lists of nested case objects (not hex strings).
    seal_tests: Vec<HpkeSealTestCase>,
    export_tests: Vec<HpkeExportTestCase>,
}
232
/// One HPKE encryption vector.
///
/// Each vector holds two independent ciphertexts of the same plaintext to the
/// suite's receiver key.
#[derive(serde::Serialize, serde::Deserialize)]
struct HpkeSealTestCase {
    /// Message that both ciphertexts must decrypt to.
    #[serde(with = "hex::serde")]
    plaintext: Vec<u8>,
    /// HPKE `info` string used for both ciphertexts.
    #[serde(with = "hex::serde")]
    info: Vec<u8>,
    /// Associated data; an empty value is treated as "no AAD" (`None`).
    #[serde(with = "hex::serde")]
    aad: Vec<u8>,

    // Output of the single-shot `hpke_seal` API.
    #[serde(with = "hex::serde")]
    sealed_kem_output: Vec<u8>,
    #[serde(with = "hex::serde")]
    sealed_ciphertext: Vec<u8>,

    // Output of `hpke_setup_s` followed by one `seal` on the sender context.
    #[serde(with = "hex::serde")]
    setup_s_kem_output: Vec<u8>,
    #[serde(with = "hex::serde")]
    setup_s_ciphertext: Vec<u8>,
}
254
/// One HPKE secret-export vector, produced from a sender context
/// (`hpke_setup_s`) and checked through a receiver context (`hpke_setup_r`).
#[derive(serde::Serialize, serde::Deserialize)]
struct HpkeExportTestCase {
    /// HPKE `info` string used to set up the contexts.
    #[serde(with = "hex::serde")]
    info: Vec<u8>,
    /// KEM output returned by `hpke_setup_s`.
    #[serde(with = "hex::serde")]
    kem_output: Vec<u8>,

    /// Context string passed to `export`.
    #[serde(with = "hex::serde")]
    exporter_context: Vec<u8>,
    /// Requested output length passed to `export`.
    exported_len: usize,
    /// Expected exported secret.
    #[serde(with = "hex::serde")]
    exported: Vec<u8>,
}
268
269 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
verify_hpke_tests<C: CipherSuiteProvider>(cs: &C, test_cases: HpkeTestCases)270 async fn verify_hpke_tests<C: CipherSuiteProvider>(cs: &C, test_cases: HpkeTestCases) {
271 let generated = generate_hpke_tests(cs).await;
272 verify_hpke_test(cs, generated).await;
273 verify_hpke_test(cs, test_cases).await;
274 }
275
276 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
verify_hpke_test<C: CipherSuiteProvider>(cs: &C, test_cases: HpkeTestCases)277 async fn verify_hpke_test<C: CipherSuiteProvider>(cs: &C, test_cases: HpkeTestCases) {
278 let (secret, public) = cs.kem_derive(&test_cases.ikm).await.unwrap();
279
280 assert_eq!(&secret, &test_cases.secret.into());
281 assert_eq!(&public, &test_cases.public.into());
282
283 for test in test_cases.seal_tests {
284 let ct = HpkeCiphertext {
285 kem_output: test.sealed_kem_output.clone(),
286 ciphertext: test.sealed_ciphertext.clone(),
287 };
288
289 test_open_ciphertext(cs, &secret, &public, &ct, &test).await;
290
291 let ct = HpkeCiphertext {
292 kem_output: test.setup_s_kem_output.clone(),
293 ciphertext: test.setup_s_ciphertext.clone(),
294 };
295
296 test_open_ciphertext(cs, &secret, &public, &ct, &test).await;
297 }
298
299 for test in test_cases.export_tests {
300 let context_r = cs
301 .hpke_setup_r(&test.kem_output, &secret, &public, &test.info)
302 .await
303 .unwrap();
304
305 let exported = context_r
306 .export(&test.exporter_context, test.exported_len)
307 .await
308 .unwrap();
309
310 assert_eq!(exported, test.exported);
311 }
312 }
313
314 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_open_ciphertext<C: CipherSuiteProvider>( cs: &C, secret: &HpkeSecretKey, public: &HpkePublicKey, ct: &HpkeCiphertext, test: &HpkeSealTestCase, )315 async fn test_open_ciphertext<C: CipherSuiteProvider>(
316 cs: &C,
317 secret: &HpkeSecretKey,
318 public: &HpkePublicKey,
319 ct: &HpkeCiphertext,
320 test: &HpkeSealTestCase,
321 ) {
322 let aad = (!test.aad.is_empty()).then_some(test.aad.as_slice());
323
324 let opened = cs
325 .hpke_open(ct, secret, public, &test.info, aad)
326 .await
327 .unwrap();
328
329 assert_eq!(&opened, &test.plaintext);
330
331 let mut context_r = cs
332 .hpke_setup_r(&ct.kem_output, secret, public, &test.info)
333 .await
334 .unwrap();
335
336 let opened = context_r.open(aad, &ct.ciphertext).await.unwrap();
337 assert_eq!(&opened, &test.plaintext);
338 }
339
340 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
341 #[cfg_attr(coverage_nightly, coverage(off))]
generate_hpke_tests<C: CipherSuiteProvider>(cs: &C) -> HpkeTestCases342 async fn generate_hpke_tests<C: CipherSuiteProvider>(cs: &C) -> HpkeTestCases {
343 let ikm = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap();
344 let (secret, public) = cs.kem_derive(&ikm).await.unwrap();
345
346 let sizes_iter = DATA_SIZES.iter().copied();
347
348 let mut seal_tests = Vec::new();
349
350 for ((pt_size, info_size), aad_size) in sizes_iter
351 .clone()
352 .skip(1)
353 .cartesian_product(sizes_iter.clone())
354 .cartesian_product(sizes_iter.clone())
355 {
356 let plaintext = cs.random_bytes_vec(pt_size).unwrap();
357 let info = cs.random_bytes_vec(info_size).unwrap();
358 let aad = cs.random_bytes_vec(aad_size).unwrap();
359
360 let sealed = cs
361 .hpke_seal(&public, &info, (aad_size > 0).then_some(&aad), &plaintext)
362 .await
363 .unwrap();
364
365 let (setup_s_kem_output, mut context_s) = cs.hpke_setup_s(&public, &info).await.unwrap();
366
367 let setup_s_ciphertext = context_s
368 .seal((aad_size > 0).then_some(&aad), &plaintext)
369 .await
370 .unwrap();
371
372 seal_tests.push(HpkeSealTestCase {
373 plaintext,
374 info,
375 aad,
376 sealed_kem_output: sealed.kem_output,
377 sealed_ciphertext: sealed.ciphertext,
378 setup_s_kem_output,
379 setup_s_ciphertext,
380 })
381 }
382
383 let mut export_tests = Vec::new();
384
385 for ((context_len, exported_len), info_size) in sizes_iter
386 .clone()
387 .cartesian_product(sizes_iter.clone().skip(1))
388 .cartesian_product(sizes_iter)
389 {
390 let exporter_context = cs.random_bytes_vec(context_len).unwrap();
391 let info = cs.random_bytes_vec(info_size).unwrap();
392 let (kem_output, context) = cs.hpke_setup_s(&public, &info).await.unwrap();
393
394 let exported = context
395 .export(&exporter_context, exported_len)
396 .await
397 .unwrap();
398
399 export_tests.push(HpkeExportTestCase {
400 info,
401 kem_output,
402 exporter_context,
403 exported_len,
404 exported,
405 });
406 }
407
408 HpkeTestCases {
409 ikm,
410 secret: secret.to_vec(),
411 public: public.to_vec(),
412 seal_tests,
413 export_tests,
414 }
415 }
416
/// One HKDF vector: an extract step (`salt`, `ikm` -> `prk`) followed by an
/// expand step (`prk`, `info`, `len` -> `okm`). Byte fields are hex in JSON.
#[derive(serde::Deserialize, serde::Serialize)]
struct HkdfTestCase {
    /// Input keying material for extract.
    #[serde(with = "hex::serde")]
    pub ikm: Vec<u8>,
    /// Salt for extract.
    #[serde(with = "hex::serde")]
    pub salt: Vec<u8>,
    /// Info string for expand.
    #[serde(with = "hex::serde")]
    pub info: Vec<u8>,
    /// Requested output length for expand.
    pub len: usize,
    /// Expected output of extract.
    #[serde(with = "hex::serde")]
    pub prk: Vec<u8>,
    /// Expected output of expand.
    #[serde(with = "hex::serde")]
    pub okm: Vec<u8>,
}
431
432 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
verify_hkdf_tests<C: CipherSuiteProvider>(cs: &C, test_cases: Vec<HkdfTestCase>)433 async fn verify_hkdf_tests<C: CipherSuiteProvider>(cs: &C, test_cases: Vec<HkdfTestCase>) {
434 for case in test_cases {
435 let extracted = cs.kdf_extract(&case.salt, &case.ikm).await.unwrap();
436
437 assert_eq!(extracted.to_vec(), case.prk);
438
439 let expanded = cs
440 .kdf_expand(&case.prk, &case.info, case.len)
441 .await
442 .unwrap();
443
444 assert_eq!(expanded.to_vec(), case.okm);
445 }
446 }
447
/// Builds randomized HKDF vectors for `cs`.
///
/// Covers every (ikm length, salt length, info length, output length)
/// combination drawn from `DATA_SIZES`. The first (zero) entry is skipped for
/// the IKM length and the output length.
///
/// # Panics
///
/// Panics if random generation, extract or expand fails.
#[cfg(all(not(mls_build_async), not(target_arch = "wasm32"), feature = "std"))]
#[cfg_attr(coverage_nightly, coverage(off))]
fn generate_hkdf_tests<C: CipherSuiteProvider>(cs: &C) -> Vec<HkdfTestCase> {
    let all_sizes = DATA_SIZES.iter().copied();
    let nonzero_sizes = all_sizes.clone().skip(1);

    let combinations = nonzero_sizes
        .clone()
        .cartesian_product(all_sizes.clone())
        .cartesian_product(all_sizes)
        .cartesian_product(nonzero_sizes);

    let mut cases = Vec::new();

    for (((ikm_len, salt_len), info_len), len) in combinations {
        let ikm = cs.random_bytes_vec(ikm_len).unwrap();
        let salt = cs.random_bytes_vec(salt_len).unwrap();
        let info = cs.random_bytes_vec(info_len).unwrap();

        let prk = cs.kdf_extract(&salt, &ikm).unwrap().to_vec();
        let okm = cs.kdf_expand(&prk, &info, len).unwrap().to_vec();

        cases.push(HkdfTestCase {
            ikm,
            salt,
            info,
            len,
            prk,
            okm,
        });
    }

    cases
}
479
// MAC known-answer vectors from RFC 4231.
// `generate_tests` never regenerates these; they are written back as loaded.
#[derive(serde::Deserialize, serde::Serialize)]
struct MacTestCase {
    /// MAC key (hex in JSON).
    #[serde(with = "hex::serde")]
    key: Vec<u8>,
    /// Authenticated message.
    #[serde(with = "hex::serde")]
    data: Vec<u8>,
    /// Expected output of `mac(key, data)`.
    #[serde(with = "hex::serde")]
    tag: Vec<u8>,
}
490
491 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
verify_mac_tests<C: CipherSuiteProvider>(cs: &C, test_cases: Vec<MacTestCase>)492 async fn verify_mac_tests<C: CipherSuiteProvider>(cs: &C, test_cases: Vec<MacTestCase>) {
493 for case in test_cases {
494 let computed = cs.mac(&case.key, &case.data).await.unwrap();
495 assert_eq!(computed, case.tag);
496 }
497 }
498
/// One hash vector. Both fields are hex in JSON. `generate_tests` never
/// regenerates these; they are written back as loaded.
#[derive(serde::Deserialize, serde::Serialize)]
struct HashTestCase {
    /// Data to hash.
    #[serde(with = "hex::serde")]
    input: Vec<u8>,
    /// Expected output of `hash(input)`.
    #[serde(with = "hex::serde")]
    output: Vec<u8>,
}
506
507 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
verify_hash_tests<C: CipherSuiteProvider>(cs: &C, test_cases: Vec<HashTestCase>)508 async fn verify_hash_tests<C: CipherSuiteProvider>(cs: &C, test_cases: Vec<HashTestCase>) {
509 for case in test_cases {
510 let computed = cs.hash(&case.input).await.unwrap();
511 assert_eq!(computed, case.output);
512 }
513 }
514
515 mod hpke_rfc_conformance {
516 use alloc::vec::Vec;
517
518 use crate::crypto::{CipherSuite, HpkeContextR, HpkeContextS, HpkeModeId};
519
520 #[derive(serde::Deserialize, Debug, Clone)]
521 pub struct TestCaseAlgo {
522 pub kem_id: u16,
523 pub kdf_id: u16,
524 pub aead_id: u16,
525 pub mode: u8,
526 }
527
528 impl TestCaseAlgo {
cipher_suite(&self) -> Option<CipherSuite>529 fn cipher_suite(&self) -> Option<CipherSuite> {
530 if ![HpkeModeId::Base as u8, HpkeModeId::Psk as u8].contains(&self.mode) {
531 return None;
532 }
533
534 match (self.kem_id, self.kdf_id, self.aead_id) {
535 (0x0010, 0x0001, 0x0001) => Some(CipherSuite::P256_AES128),
536 (0x0011, 0x0002, 0x0002) => Some(CipherSuite::P384_AES256),
537 (0x0012, 0x0003, 0x0002) => Some(CipherSuite::P521_AES256),
538 (0x0020, 0x0001, 0x0001) => Some(CipherSuite::CURVE25519_AES128),
539 (0x0020, 0x0001, 0x0003) => Some(CipherSuite::CURVE25519_CHACHA),
540 (0x0021, 0x0003, 0x0002) => Some(CipherSuite::CURVE448_AES256),
541 (0x0021, 0x0003, 0x0003) => Some(CipherSuite::CURVE448_CHACHA),
542 _ => None,
543 }
544 }
545 }
546
547 #[derive(serde::Deserialize, Debug)]
548 struct TestCase {
549 #[serde(flatten)]
550 algo: TestCaseAlgo,
551 #[serde(with = "hex::serde", rename(deserialize = "pkRm"))]
552 pk_rm: Vec<u8>,
553 #[serde(with = "hex::serde", rename(deserialize = "skRm"))]
554 sk_rm: Vec<u8>,
555 #[serde(with = "hex::serde", rename(deserialize = "ikmE"))]
556 ikm_e: Vec<u8>,
557 #[serde(with = "hex::serde")]
558 shared_secret: Vec<u8>,
559 #[serde(with = "hex::serde")]
560 enc: Vec<u8>,
561 #[serde(with = "hex::serde")]
562 exporter_secret: Vec<u8>,
563 #[serde(with = "hex::serde")]
564 base_nonce: Vec<u8>,
565 #[serde(with = "hex::serde")]
566 key: Vec<u8>,
567 encryptions: Vec<EncryptionTestCase>,
568 exports: Vec<ExportTestCase>,
569 }
570
571 #[derive(serde::Deserialize, Debug)]
572 struct EncryptionTestCase {
573 #[serde(with = "hex::serde", rename = "pt")]
574 plaintext: Vec<u8>,
575 #[serde(with = "hex::serde")]
576 aad: Vec<u8>,
577 #[serde(with = "hex::serde", rename = "ct")]
578 ciphertext: Vec<u8>,
579 }
580
581 #[derive(serde::Deserialize, Debug)]
582 struct ExportTestCase {
583 #[serde(with = "hex::serde")]
584 exporter_context: Vec<u8>,
585 #[serde(rename = "L")]
586 length: usize,
587 #[serde(with = "hex::serde")]
588 exported_value: Vec<u8>,
589 }
590
591 #[cfg(any(target_arch = "wasm32", not(feature = "std")))]
get_test_cases() -> Vec<TestCase>592 fn get_test_cases() -> Vec<TestCase> {
593 let bytes = include_bytes!(concat!(
594 env!("CARGO_MANIFEST_DIR"),
595 "/test_data/test_hpke.json"
596 ));
597
598 serde_json::from_slice(bytes).unwrap()
599 }
600
601 #[cfg(all(not(target_arch = "wasm32"), feature = "std"))]
get_test_cases() -> Vec<TestCase>602 fn get_test_cases() -> Vec<TestCase> {
603 let path = concat!(env!("CARGO_MANIFEST_DIR"), "/test_data/test_hpke.json");
604
605 serde_json::from_slice(&std::fs::read(path).unwrap()).unwrap()
606 }
607
608 pub struct EncapOutput {
609 pub enc: Vec<u8>,
610 pub shared_secret: Vec<u8>,
611 }
612
613 impl EncapOutput {
new(enc: Vec<u8>, shared_secret: Vec<u8>) -> Self614 pub fn new(enc: Vec<u8>, shared_secret: Vec<u8>) -> Self {
615 Self { enc, shared_secret }
616 }
617 }
618
619 pub trait TestHpke {
620 type ContextS: HpkeContextS;
621 type ContextR: HpkeContextR;
622
hpke_context( &self, key: Vec<u8>, base_nonce: Vec<u8>, exporter_secret: Vec<u8>, ) -> (Self::ContextS, Self::ContextR)623 fn hpke_context(
624 &self,
625 key: Vec<u8>,
626 base_nonce: Vec<u8>,
627 exporter_secret: Vec<u8>,
628 ) -> (Self::ContextS, Self::ContextR);
629
encap(&mut self, ikm_e: Vec<u8>, pk_rm: Vec<u8>) -> EncapOutput630 fn encap(&mut self, ikm_e: Vec<u8>, pk_rm: Vec<u8>) -> EncapOutput;
decap(&mut self, enc: Vec<u8>, sk_rm: Vec<u8>, pk_rm: Vec<u8>) -> Vec<u8>631 fn decap(&mut self, enc: Vec<u8>, sk_rm: Vec<u8>, pk_rm: Vec<u8>) -> Vec<u8>;
632 }
633
634 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
verify_hpke_context_tests<C: TestHpke>(hpke: &C, cipher_suite: CipherSuite)635 pub async fn verify_hpke_context_tests<C: TestHpke>(hpke: &C, cipher_suite: CipherSuite) {
636 for test_case in get_test_cases()
637 .into_iter()
638 .filter(|tc| matches!(tc.algo.cipher_suite(), Some(c) if c == cipher_suite))
639 {
640 let (mut context_s, mut context_r) = hpke.hpke_context(
641 test_case.key,
642 test_case.base_nonce,
643 test_case.exporter_secret,
644 );
645
646 for enc_test_case in test_case.encryptions {
647 // Encrypt
648 let ct = context_s
649 .seal(Some(&enc_test_case.aad), &enc_test_case.plaintext)
650 .await
651 .unwrap();
652
653 assert_eq!(ct, enc_test_case.ciphertext);
654
655 // Decrypt
656 let pt = context_r.open(Some(&enc_test_case.aad), &ct).await.unwrap();
657
658 assert_eq!(pt, enc_test_case.plaintext);
659 }
660
661 for test in test_case.exports {
662 let exported_s = context_s.export(&test.exporter_context, test.length).await;
663 assert_eq!(exported_s.unwrap(), test.exported_value);
664
665 let exported_r = context_r.export(&test.exporter_context, test.length).await;
666 assert_eq!(exported_r.unwrap(), test.exported_value);
667 }
668 }
669 }
670
671 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
verify_hpke_encap_tests<C: TestHpke>(hpke: &mut C, cipher_suite: CipherSuite)672 pub async fn verify_hpke_encap_tests<C: TestHpke>(hpke: &mut C, cipher_suite: CipherSuite) {
673 for test_case in get_test_cases()
674 .into_iter()
675 .filter(|tc| matches!(tc.algo.cipher_suite(), Some(c) if c == cipher_suite))
676 {
677 let out = hpke.encap(test_case.ikm_e, test_case.pk_rm.clone());
678
679 assert_eq!(&out.enc, &test_case.enc);
680 assert_eq!(&out.shared_secret, &test_case.shared_secret);
681
682 let shared_secret = hpke.decap(test_case.enc, test_case.sk_rm, test_case.pk_rm);
683
684 assert_eq!(shared_secret, test_case.shared_secret);
685 }
686 }
687 }
688