1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec::Vec;
6 use itertools::Itertools;
7 
8 use crate::crypto::HpkeContextR;
9 
10 use super::{
11     CipherSuiteProvider, CryptoProvider, HpkeCiphertext, HpkeContextS, HpkePublicKey, HpkeSecretKey,
12 };
13 
/// Location of the JSON test-vector file that `verify_tests` reads and
/// `generate_tests` rewrites on native builds with the `std` feature.
///
/// Only compiled where it is actually used: wasm32 and `no_std` builds embed the
/// same file through `SERIALIZED_TEST_SUITES` instead, and would otherwise emit an
/// unused-constant warning for this item.
#[cfg(all(not(target_arch = "wasm32"), feature = "std"))]
const PATH: &str = concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/test_data/crypto_provider.json"
);
18 
/// Contents of `test_data/crypto_provider.json`, embedded at compile time.
///
/// `verify_tests` parses this on targets where it does not read the file from disk
/// (wasm32, or builds without the `std` feature). `include_bytes!` needs a literal
/// path, so the path is spelled out again here rather than reusing `PATH`; keep the
/// two in sync.
#[cfg(any(target_arch = "wasm32", not(feature = "std")))]
const SERIALIZED_TEST_SUITES: &[u8] = include_bytes!(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/test_data/crypto_provider.json"
));
24 
25 pub use hpke_rfc_conformance::{
26     verify_hpke_context_tests, verify_hpke_encap_tests, EncapOutput, TestHpke,
27 };
28 
/// Input lengths (in bytes) used when generating signature, HPKE and HKDF test
/// vectors. Some generators drop the leading `0` where an empty input is not
/// exercised (HPKE plaintext and exported lengths, HKDF IKM and output lengths).
pub const DATA_SIZES: [usize; 5] = [0, 1, 16, 123, 2000];
30 
/// All stored test vectors for one cipher suite, as serialized in
/// `test_data/crypto_provider.json`.
///
/// Every vector field is `#[serde(default)]`, so a suite entry in the JSON may omit
/// any category. `generate_tests` only regenerates the signature, HPKE and HKDF
/// vectors; the AEAD, MAC and hash vectors are kept as loaded from the existing file
/// (empty when no file existed yet).
#[derive(serde::Serialize, serde::Deserialize, Default)]
struct TestSuite {
    /// Numeric cipher suite identifier; converted with `.into()` before being passed
    /// to `CryptoProvider::cipher_suite_provider`.
    cipher_suite: u16,
    /// Signature key pairs with a signature over random data.
    #[serde(default)]
    signature_tests: Vec<SignatureTestCase>,
    /// Known-answer AEAD seal/open vectors.
    #[serde(default)]
    aead_tests: Vec<AeadTestCase>,
    /// HPKE key-derivation, seal and export vectors.
    #[serde(default)]
    hpke_tests: HpkeTestCases,
    /// HKDF extract/expand vectors.
    #[serde(default)]
    hkdf_tests: Vec<HkdfTestCase>,
    /// Known-answer MAC vectors.
    #[serde(default)]
    mac_tests: Vec<MacTestCase>,
    /// Known-answer hash vectors.
    #[serde(default)]
    hash_tests: Vec<HashTestCase>,
}
47 
/// Regenerates the signature, HPKE and HKDF vectors in `PATH` using `crypto`.
///
/// Existing suites (and their AEAD/MAC/hash vectors) are loaded from `PATH` when the
/// file exists; otherwise one empty suite per supported cipher suite is created.
/// The result is written back to `PATH` as pretty-printed JSON.
///
/// # Panics
/// Panics if `crypto` cannot instantiate one of its advertised cipher suites or any
/// suite listed in the loaded file, on any provider error, or on file/JSON errors.
#[cfg(all(not(mls_build_async), not(target_arch = "wasm32"), feature = "std"))]
#[cfg_attr(coverage_nightly, coverage(off))]
pub fn generate_tests<C: CryptoProvider>(crypto: &C) {
    // Fail fast if the provider advertises a suite it cannot actually instantiate.
    crypto
        .supported_cipher_suites()
        .into_iter()
        .for_each(|suite_id| {
            crypto.cipher_suite_provider(suite_id).unwrap();
        });

    let mut suites = create_or_load_tests(crypto);

    for suite in &mut suites {
        let provider = crypto
            .cipher_suite_provider(suite.cipher_suite.into())
            .unwrap();

        suite.signature_tests = generate_signature_tests(&provider);
        suite.hpke_tests = generate_hpke_tests(&provider);
        suite.hkdf_tests = generate_hkdf_tests(&provider);
    }

    let json = serde_json::to_string_pretty(&suites).unwrap();
    std::fs::write(PATH, json).unwrap();
}
68 
/// Loads the existing test suites from `PATH`, or, when the file does not exist,
/// starts one empty `TestSuite` for every cipher suite `crypto` supports.
///
/// # Panics
/// Panics if the file exists but cannot be read or parsed.
#[cfg(all(not(mls_build_async), not(target_arch = "wasm32"), feature = "std"))]
#[cfg_attr(coverage_nightly, coverage(off))]
fn create_or_load_tests<C: CryptoProvider>(crypto: &C) -> Vec<TestSuite> {
    let path = std::path::Path::new(PATH);

    if !path.exists() {
        // No vectors yet: seed an empty suite per supported cipher suite.
        return crypto
            .supported_cipher_suites()
            .into_iter()
            .map(|suite_id| TestSuite {
                cipher_suite: suite_id.into(),
                ..Default::default()
            })
            .collect();
    }

    let bytes = std::fs::read(path).unwrap();
    serde_json::from_slice(&bytes).unwrap()
}
85 
86 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
verify_tests<C: CryptoProvider>(crypto: &C, signature_secret_key_compatible: bool)87 pub async fn verify_tests<C: CryptoProvider>(crypto: &C, signature_secret_key_compatible: bool) {
88     #[cfg(any(target_arch = "wasm32", not(feature = "std")))]
89     let test_suites: Vec<TestSuite> = serde_json::from_slice(SERIALIZED_TEST_SUITES).unwrap();
90 
91     #[cfg(all(not(target_arch = "wasm32"), feature = "std"))]
92     let test_suites: Vec<TestSuite> =
93         serde_json::from_slice(&std::fs::read(PATH).unwrap()).unwrap();
94 
95     for test_suite in test_suites {
96         let test_cs = test_suite.cipher_suite.into();
97 
98         let Some(cs) = crypto.cipher_suite_provider(test_cs) else {
99             continue;
100         };
101 
102         assert_eq!(cs.cipher_suite(), test_cs);
103 
104         verify_hkdf_tests(&cs, test_suite.hkdf_tests).await;
105         verify_aead_tests(&cs, test_suite.aead_tests).await;
106         verify_mac_tests(&cs, test_suite.mac_tests).await;
107         verify_hpke_tests(&cs, test_suite.hpke_tests).await;
108 
109         verify_signature_tests(
110             &cs,
111             test_suite.signature_tests,
112             signature_secret_key_compatible,
113         )
114         .await;
115 
116         verify_hash_tests(&cs, test_suite.hash_tests).await;
117     }
118 }
119 
/// A signature key pair together with a signature over `data`, all hex-encoded in
/// the JSON file.
#[derive(serde::Serialize, serde::Deserialize)]
struct SignatureTestCase {
    /// Signature secret key bytes; only used by the verifier when the provider is
    /// declared secret-key compatible (or the case was generated locally).
    #[serde(with = "hex::serde")]
    secret: Vec<u8>,
    /// Signature public key bytes used to verify `signature`.
    #[serde(with = "hex::serde")]
    public: Vec<u8>,
    /// The signed message.
    #[serde(with = "hex::serde")]
    data: Vec<u8>,
    /// Signature over `data` made with `secret`.
    #[serde(with = "hex::serde")]
    signature: Vec<u8>,
}
131 
132 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
verify_signature_tests<C: CipherSuiteProvider>( cs: &C, test_cases: Vec<SignatureTestCase>, secret_key_compatible: bool, )133 async fn verify_signature_tests<C: CipherSuiteProvider>(
134     cs: &C,
135     test_cases: Vec<SignatureTestCase>,
136     secret_key_compatible: bool,
137 ) {
138     // Checks that `cs` can sign and verify
139     let generated = generate_signature_tests(cs).await;
140 
141     for (test_case, is_generated) in test_cases
142         .into_iter()
143         .map(|tc| (tc, false))
144         .chain(generated.into_iter().map(|tc| (tc, true)))
145     {
146         let public = test_case.public.into();
147 
148         // Checks that `cs` can verify signatures generated by itself and another implementation
149         cs.verify(&public, &test_case.signature, &test_case.data)
150             .await
151             .unwrap();
152 
153         if is_generated || secret_key_compatible {
154             let secret = test_case.secret.into();
155 
156             let derived = cs.signature_key_derive_public(&secret).await.unwrap();
157 
158             cs.sign(&secret, b"hello world").await.unwrap();
159 
160             assert_eq!(derived, public);
161         }
162     }
163 }
164 
165 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
166 #[cfg_attr(coverage_nightly, coverage(off))]
generate_signature_tests<C: CipherSuiteProvider>(cs: &C) -> Vec<SignatureTestCase>167 async fn generate_signature_tests<C: CipherSuiteProvider>(cs: &C) -> Vec<SignatureTestCase> {
168     let mut tests = Vec::new();
169 
170     for data_size in DATA_SIZES {
171         let data = cs.random_bytes_vec(data_size).unwrap();
172         let (secret, public) = cs.signature_key_generate().await.unwrap();
173         let signature = cs.sign(&secret, &data).await.unwrap();
174 
175         tests.push(SignatureTestCase {
176             secret: secret.to_vec(),
177             public: public.to_vec(),
178             data,
179             signature,
180         });
181     }
182 
183     tests
184 }
185 
/// Known-answer AEAD vector, hex-encoded in the JSON file.
///
/// NOTE(review): the original comment says these come from "the RFC", but it does
/// not record which RFC; confirm the source of the vectors in the test data.
#[derive(serde::Deserialize, serde::Serialize)]
struct AeadTestCase {
    /// AEAD key.
    #[serde(with = "hex::serde")]
    pub key: Vec<u8>,
    /// Nonce passed as the last argument to `aead_seal` / `aead_open`.
    #[serde(with = "hex::serde")]
    pub iv: Vec<u8>,
    /// Expected output of `aead_seal` (presumably ciphertext with the tag appended;
    /// TODO confirm).
    #[serde(with = "hex::serde")]
    pub ct: Vec<u8>,
    /// Associated data; always passed as `Some(..)`, even when empty.
    #[serde(with = "hex::serde")]
    pub aad: Vec<u8>,
    /// Plaintext.
    #[serde(with = "hex::serde")]
    pub pt: Vec<u8>,
}
200 
201 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
verify_aead_tests<C: CipherSuiteProvider>(cs: &C, test_cases: Vec<AeadTestCase>)202 async fn verify_aead_tests<C: CipherSuiteProvider>(cs: &C, test_cases: Vec<AeadTestCase>) {
203     for case in test_cases {
204         let ciphertext = cs
205             .aead_seal(&case.key, &case.pt, Some(&case.aad), &case.iv)
206             .await
207             .unwrap();
208 
209         assert_eq!(ciphertext, case.ct);
210 
211         let plaintext = cs
212             .aead_open(&case.key, &ciphertext, Some(&case.aad), &case.iv)
213             .await
214             .unwrap();
215 
216         assert_eq!(plaintext.to_vec(), case.pt);
217     }
218 }
219 
/// HPKE vectors for one cipher suite, all sharing a single receiver key pair.
///
/// `seal_tests` and `export_tests` carry no `#[serde(default)]`, so both must be
/// present in the JSON whenever an `hpke_tests` object is.
#[derive(serde::Serialize, serde::Deserialize, Default)]
struct HpkeTestCases {
    /// Input keying material fed to `kem_derive` to obtain the receiver key pair.
    #[serde(with = "hex::serde")]
    ikm: Vec<u8>,
    /// Expected receiver secret key derived from `ikm`.
    #[serde(with = "hex::serde")]
    secret: Vec<u8>,
    /// Expected receiver public key derived from `ikm`.
    #[serde(with = "hex::serde")]
    public: Vec<u8>,

    /// Ciphertexts that must open to their recorded plaintext.
    seal_tests: Vec<HpkeSealTestCase>,
    /// Exporter outputs that must be reproduced by the receiver.
    export_tests: Vec<HpkeExportTestCase>,
}
232 
/// One plaintext sealed to the receiver key in two ways: single-shot `hpke_seal`
/// and an explicit `hpke_setup_s` context.
#[derive(serde::Serialize, serde::Deserialize)]
struct HpkeSealTestCase {
    /// Expected plaintext after opening.
    #[serde(with = "hex::serde")]
    plaintext: Vec<u8>,
    /// HPKE `info` used when sealing and when opening.
    #[serde(with = "hex::serde")]
    info: Vec<u8>,
    /// Associated data; an empty value means no AAD (`None`) was used.
    #[serde(with = "hex::serde")]
    aad: Vec<u8>,

    // Seal and open
    /// KEM output from the single-shot `hpke_seal`.
    #[serde(with = "hex::serde")]
    sealed_kem_output: Vec<u8>,
    /// Ciphertext from the single-shot `hpke_seal`.
    #[serde(with = "hex::serde")]
    sealed_ciphertext: Vec<u8>,

    // Setup s and r
    /// KEM output returned by `hpke_setup_s`.
    #[serde(with = "hex::serde")]
    setup_s_kem_output: Vec<u8>,
    /// Ciphertext from the first `seal` on the `hpke_setup_s` context.
    #[serde(with = "hex::serde")]
    setup_s_ciphertext: Vec<u8>,
}
254 
/// One HPKE exporter vector: the value exported from a sender context, which the
/// receiver must reproduce from `kem_output`.
#[derive(serde::Serialize, serde::Deserialize)]
struct HpkeExportTestCase {
    /// HPKE `info` used for both context setups.
    #[serde(with = "hex::serde")]
    info: Vec<u8>,
    /// KEM output returned by `hpke_setup_s`.
    #[serde(with = "hex::serde")]
    kem_output: Vec<u8>,

    /// Exporter context passed to `export`.
    #[serde(with = "hex::serde")]
    exporter_context: Vec<u8>,
    /// Number of bytes requested from `export`.
    exported_len: usize,
    /// Expected exporter output.
    #[serde(with = "hex::serde")]
    exported: Vec<u8>,
}
268 
269 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
verify_hpke_tests<C: CipherSuiteProvider>(cs: &C, test_cases: HpkeTestCases)270 async fn verify_hpke_tests<C: CipherSuiteProvider>(cs: &C, test_cases: HpkeTestCases) {
271     let generated = generate_hpke_tests(cs).await;
272     verify_hpke_test(cs, generated).await;
273     verify_hpke_test(cs, test_cases).await;
274 }
275 
276 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
verify_hpke_test<C: CipherSuiteProvider>(cs: &C, test_cases: HpkeTestCases)277 async fn verify_hpke_test<C: CipherSuiteProvider>(cs: &C, test_cases: HpkeTestCases) {
278     let (secret, public) = cs.kem_derive(&test_cases.ikm).await.unwrap();
279 
280     assert_eq!(&secret, &test_cases.secret.into());
281     assert_eq!(&public, &test_cases.public.into());
282 
283     for test in test_cases.seal_tests {
284         let ct = HpkeCiphertext {
285             kem_output: test.sealed_kem_output.clone(),
286             ciphertext: test.sealed_ciphertext.clone(),
287         };
288 
289         test_open_ciphertext(cs, &secret, &public, &ct, &test).await;
290 
291         let ct = HpkeCiphertext {
292             kem_output: test.setup_s_kem_output.clone(),
293             ciphertext: test.setup_s_ciphertext.clone(),
294         };
295 
296         test_open_ciphertext(cs, &secret, &public, &ct, &test).await;
297     }
298 
299     for test in test_cases.export_tests {
300         let context_r = cs
301             .hpke_setup_r(&test.kem_output, &secret, &public, &test.info)
302             .await
303             .unwrap();
304 
305         let exported = context_r
306             .export(&test.exporter_context, test.exported_len)
307             .await
308             .unwrap();
309 
310         assert_eq!(exported, test.exported);
311     }
312 }
313 
314 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_open_ciphertext<C: CipherSuiteProvider>( cs: &C, secret: &HpkeSecretKey, public: &HpkePublicKey, ct: &HpkeCiphertext, test: &HpkeSealTestCase, )315 async fn test_open_ciphertext<C: CipherSuiteProvider>(
316     cs: &C,
317     secret: &HpkeSecretKey,
318     public: &HpkePublicKey,
319     ct: &HpkeCiphertext,
320     test: &HpkeSealTestCase,
321 ) {
322     let aad = (!test.aad.is_empty()).then_some(test.aad.as_slice());
323 
324     let opened = cs
325         .hpke_open(ct, secret, public, &test.info, aad)
326         .await
327         .unwrap();
328 
329     assert_eq!(&opened, &test.plaintext);
330 
331     let mut context_r = cs
332         .hpke_setup_r(&ct.kem_output, secret, public, &test.info)
333         .await
334         .unwrap();
335 
336     let opened = context_r.open(aad, &ct.ciphertext).await.unwrap();
337     assert_eq!(&opened, &test.plaintext);
338 }
339 
340 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
341 #[cfg_attr(coverage_nightly, coverage(off))]
generate_hpke_tests<C: CipherSuiteProvider>(cs: &C) -> HpkeTestCases342 async fn generate_hpke_tests<C: CipherSuiteProvider>(cs: &C) -> HpkeTestCases {
343     let ikm = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap();
344     let (secret, public) = cs.kem_derive(&ikm).await.unwrap();
345 
346     let sizes_iter = DATA_SIZES.iter().copied();
347 
348     let mut seal_tests = Vec::new();
349 
350     for ((pt_size, info_size), aad_size) in sizes_iter
351         .clone()
352         .skip(1)
353         .cartesian_product(sizes_iter.clone())
354         .cartesian_product(sizes_iter.clone())
355     {
356         let plaintext = cs.random_bytes_vec(pt_size).unwrap();
357         let info = cs.random_bytes_vec(info_size).unwrap();
358         let aad = cs.random_bytes_vec(aad_size).unwrap();
359 
360         let sealed = cs
361             .hpke_seal(&public, &info, (aad_size > 0).then_some(&aad), &plaintext)
362             .await
363             .unwrap();
364 
365         let (setup_s_kem_output, mut context_s) = cs.hpke_setup_s(&public, &info).await.unwrap();
366 
367         let setup_s_ciphertext = context_s
368             .seal((aad_size > 0).then_some(&aad), &plaintext)
369             .await
370             .unwrap();
371 
372         seal_tests.push(HpkeSealTestCase {
373             plaintext,
374             info,
375             aad,
376             sealed_kem_output: sealed.kem_output,
377             sealed_ciphertext: sealed.ciphertext,
378             setup_s_kem_output,
379             setup_s_ciphertext,
380         })
381     }
382 
383     let mut export_tests = Vec::new();
384 
385     for ((context_len, exported_len), info_size) in sizes_iter
386         .clone()
387         .cartesian_product(sizes_iter.clone().skip(1))
388         .cartesian_product(sizes_iter)
389     {
390         let exporter_context = cs.random_bytes_vec(context_len).unwrap();
391         let info = cs.random_bytes_vec(info_size).unwrap();
392         let (kem_output, context) = cs.hpke_setup_s(&public, &info).await.unwrap();
393 
394         let exported = context
395             .export(&exporter_context, exported_len)
396             .await
397             .unwrap();
398 
399         export_tests.push(HpkeExportTestCase {
400             info,
401             kem_output,
402             exporter_context,
403             exported_len,
404             exported,
405         });
406     }
407 
408     HpkeTestCases {
409         ikm,
410         secret: secret.to_vec(),
411         public: public.to_vec(),
412         seal_tests,
413         export_tests,
414     }
415 }
416 
/// One HKDF vector: `prk = kdf_extract(salt, ikm)` and
/// `okm = kdf_expand(prk, info, len)`, hex-encoded in the JSON file.
#[derive(serde::Deserialize, serde::Serialize)]
struct HkdfTestCase {
    /// Input keying material.
    #[serde(with = "hex::serde")]
    pub ikm: Vec<u8>,
    /// Extraction salt (may be empty).
    #[serde(with = "hex::serde")]
    pub salt: Vec<u8>,
    /// Expansion info (may be empty).
    #[serde(with = "hex::serde")]
    pub info: Vec<u8>,
    /// Requested output length passed to `kdf_expand`.
    pub len: usize,
    /// Expected extraction output.
    #[serde(with = "hex::serde")]
    pub prk: Vec<u8>,
    /// Expected expansion output.
    #[serde(with = "hex::serde")]
    pub okm: Vec<u8>,
}
431 
432 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
verify_hkdf_tests<C: CipherSuiteProvider>(cs: &C, test_cases: Vec<HkdfTestCase>)433 async fn verify_hkdf_tests<C: CipherSuiteProvider>(cs: &C, test_cases: Vec<HkdfTestCase>) {
434     for case in test_cases {
435         let extracted = cs.kdf_extract(&case.salt, &case.ikm).await.unwrap();
436 
437         assert_eq!(extracted.to_vec(), case.prk);
438 
439         let expanded = cs
440             .kdf_expand(&case.prk, &case.info, case.len)
441             .await
442             .unwrap();
443 
444         assert_eq!(expanded.to_vec(), case.okm);
445     }
446 }
447 
/// Generates HKDF vectors with `cs` for every (ikm, salt, info, output length)
/// combination drawn from `DATA_SIZES`. IKM length and output length are never zero.
///
/// # Panics
/// Panics on any provider error.
#[cfg(all(not(mls_build_async), not(target_arch = "wasm32"), feature = "std"))]
#[cfg_attr(coverage_nightly, coverage(off))]
fn generate_hkdf_tests<C: CipherSuiteProvider>(cs: &C) -> Vec<HkdfTestCase> {
    let sizes = DATA_SIZES.iter().copied();

    let combos = sizes
        .clone()
        .skip(1)
        .cartesian_product(sizes.clone())
        .cartesian_product(sizes.clone())
        .cartesian_product(sizes.skip(1));

    let mut cases = Vec::new();

    for (((ikm_size, salt_size), info_size), len) in combos {
        let ikm = cs.random_bytes_vec(ikm_size).unwrap();
        let salt = cs.random_bytes_vec(salt_size).unwrap();
        let info = cs.random_bytes_vec(info_size).unwrap();

        let prk = cs.kdf_extract(&salt, &ikm).unwrap().to_vec();
        let okm = cs.kdf_expand(&prk, &info, len).unwrap().to_vec();

        cases.push(HkdfTestCase {
            ikm,
            salt,
            info,
            len,
            prk,
            okm,
        });
    }

    cases
}
479 
/// Known-answer MAC vector: `tag` must equal `mac(key, data)`.
///
/// The vectors are taken from RFC 4231 ("Identifiers and Test Vectors for
/// HMAC-SHA-224, HMAC-SHA-256, HMAC-SHA-384, and HMAC-SHA-512").
#[derive(serde::Deserialize, serde::Serialize)]
struct MacTestCase {
    /// MAC key.
    #[serde(with = "hex::serde")]
    key: Vec<u8>,
    /// Authenticated message.
    #[serde(with = "hex::serde")]
    data: Vec<u8>,
    /// Expected MAC output.
    #[serde(with = "hex::serde")]
    tag: Vec<u8>,
}
490 
491 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
verify_mac_tests<C: CipherSuiteProvider>(cs: &C, test_cases: Vec<MacTestCase>)492 async fn verify_mac_tests<C: CipherSuiteProvider>(cs: &C, test_cases: Vec<MacTestCase>) {
493     for case in test_cases {
494         let computed = cs.mac(&case.key, &case.data).await.unwrap();
495         assert_eq!(computed, case.tag);
496     }
497 }
498 
/// Known-answer hash vector: `output` must equal `hash(input)`.
#[derive(serde::Deserialize, serde::Serialize)]
struct HashTestCase {
    /// Message to hash.
    #[serde(with = "hex::serde")]
    input: Vec<u8>,
    /// Expected digest.
    #[serde(with = "hex::serde")]
    output: Vec<u8>,
}
506 
507 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
verify_hash_tests<C: CipherSuiteProvider>(cs: &C, test_cases: Vec<HashTestCase>)508 async fn verify_hash_tests<C: CipherSuiteProvider>(cs: &C, test_cases: Vec<HashTestCase>) {
509     for case in test_cases {
510         let computed = cs.hash(&case.input).await.unwrap();
511         assert_eq!(computed, case.output);
512     }
513 }
514 
/// Conformance checks driven by the HPKE vectors in `test_data/test_hpke.json`.
///
/// Implementations plug in through the [`TestHpke`] trait. Only vectors whose
/// (KEM, KDF, AEAD, mode) combination maps to a known cipher suite are used.
mod hpke_rfc_conformance {
    use alloc::vec::Vec;

    use crate::crypto::{CipherSuite, HpkeContextR, HpkeContextS, HpkeModeId};

    /// Algorithm identifiers and mode of one HPKE vector. Deserialized from the
    /// `kem_id`, `kdf_id`, `aead_id` and `mode` keys (flattened into `TestCase`).
    #[derive(serde::Deserialize, Debug, Clone)]
    pub struct TestCaseAlgo {
        pub kem_id: u16,
        pub kdf_id: u16,
        pub aead_id: u16,
        pub mode: u8,
    }

    impl TestCaseAlgo {
        /// Maps this vector's algorithms to a `CipherSuite`.
        ///
        /// Returns `None` for any mode other than Base or Psk, and for any
        /// (KEM, KDF, AEAD) triple not listed below.
        fn cipher_suite(&self) -> Option<CipherSuite> {
            // NOTE(review): presumably the auth modes are excluded because
            // `TestHpke::encap` takes no sender key; confirm.
            if ![HpkeModeId::Base as u8, HpkeModeId::Psk as u8].contains(&self.mode) {
                return None;
            }

            match (self.kem_id, self.kdf_id, self.aead_id) {
                (0x0010, 0x0001, 0x0001) => Some(CipherSuite::P256_AES128),
                (0x0011, 0x0002, 0x0002) => Some(CipherSuite::P384_AES256),
                (0x0012, 0x0003, 0x0002) => Some(CipherSuite::P521_AES256),
                (0x0020, 0x0001, 0x0001) => Some(CipherSuite::CURVE25519_AES128),
                (0x0020, 0x0001, 0x0003) => Some(CipherSuite::CURVE25519_CHACHA),
                (0x0021, 0x0003, 0x0002) => Some(CipherSuite::CURVE448_AES256),
                (0x0021, 0x0003, 0x0003) => Some(CipherSuite::CURVE448_CHACHA),
                _ => None,
            }
        }
    }

    /// One HPKE vector from `test_hpke.json`; byte fields are hex-encoded there.
    #[derive(serde::Deserialize, Debug)]
    struct TestCase {
        #[serde(flatten)]
        algo: TestCaseAlgo,
        /// Receiver public key (JSON key `pkRm`).
        #[serde(with = "hex::serde", rename(deserialize = "pkRm"))]
        pk_rm: Vec<u8>,
        /// Receiver secret key (JSON key `skRm`).
        #[serde(with = "hex::serde", rename(deserialize = "skRm"))]
        sk_rm: Vec<u8>,
        /// Ephemeral IKM fed to `TestHpke::encap` (JSON key `ikmE`).
        #[serde(with = "hex::serde", rename(deserialize = "ikmE"))]
        ikm_e: Vec<u8>,
        /// Expected KEM shared secret.
        #[serde(with = "hex::serde")]
        shared_secret: Vec<u8>,
        /// Expected encapsulated key.
        #[serde(with = "hex::serde")]
        enc: Vec<u8>,
        /// Exporter secret used to build the contexts directly.
        #[serde(with = "hex::serde")]
        exporter_secret: Vec<u8>,
        /// Base nonce used to build the contexts directly.
        #[serde(with = "hex::serde")]
        base_nonce: Vec<u8>,
        /// AEAD key used to build the contexts directly.
        #[serde(with = "hex::serde")]
        key: Vec<u8>,
        encryptions: Vec<EncryptionTestCase>,
        exports: Vec<ExportTestCase>,
    }

    /// One context encryption: sealing `pt` with `aad` must give `ct`.
    #[derive(serde::Deserialize, Debug)]
    struct EncryptionTestCase {
        #[serde(with = "hex::serde", rename = "pt")]
        plaintext: Vec<u8>,
        #[serde(with = "hex::serde")]
        aad: Vec<u8>,
        #[serde(with = "hex::serde", rename = "ct")]
        ciphertext: Vec<u8>,
    }

    /// One exporter vector: exporting `length` bytes (JSON key `L`) with
    /// `exporter_context` must give `exported_value`.
    #[derive(serde::Deserialize, Debug)]
    struct ExportTestCase {
        #[serde(with = "hex::serde")]
        exporter_context: Vec<u8>,
        #[serde(rename = "L")]
        length: usize,
        #[serde(with = "hex::serde")]
        exported_value: Vec<u8>,
    }

    /// Loads the vectors from the copy of `test_hpke.json` embedded at compile time
    /// (wasm32 or no `std`).
    #[cfg(any(target_arch = "wasm32", not(feature = "std")))]
    fn get_test_cases() -> Vec<TestCase> {
        let bytes = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/test_data/test_hpke.json"
        ));

        serde_json::from_slice(bytes).unwrap()
    }

    /// Loads the vectors by reading `test_hpke.json` from disk on each call
    /// (native builds with `std`).
    #[cfg(all(not(target_arch = "wasm32"), feature = "std"))]
    fn get_test_cases() -> Vec<TestCase> {
        let path = concat!(env!("CARGO_MANIFEST_DIR"), "/test_data/test_hpke.json");

        serde_json::from_slice(&std::fs::read(path).unwrap()).unwrap()
    }

    /// Result of `TestHpke::encap`: the encapsulated key and the KEM shared secret.
    pub struct EncapOutput {
        pub enc: Vec<u8>,
        pub shared_secret: Vec<u8>,
    }

    impl EncapOutput {
        /// Bundles an encapsulated key and shared secret.
        pub fn new(enc: Vec<u8>, shared_secret: Vec<u8>) -> Self {
            Self { enc, shared_secret }
        }
    }

    /// Hooks an HPKE implementation exposes so it can be checked against the vectors.
    pub trait TestHpke {
        type ContextS: HpkeContextS;
        type ContextR: HpkeContextR;

        /// Builds a sender/receiver context pair directly from the key-schedule
        /// outputs (`key`, `base_nonce`, `exporter_secret`), bypassing the KEM.
        fn hpke_context(
            &self,
            key: Vec<u8>,
            base_nonce: Vec<u8>,
            exporter_secret: Vec<u8>,
        ) -> (Self::ContextS, Self::ContextR);

        /// Encapsulates to `pk_rm` using `ikm_e` as the ephemeral key material.
        fn encap(&mut self, ikm_e: Vec<u8>, pk_rm: Vec<u8>) -> EncapOutput;
        /// Decapsulates `enc` with the receiver key pair, returning the shared secret.
        fn decap(&mut self, enc: Vec<u8>, sk_rm: Vec<u8>, pk_rm: Vec<u8>) -> Vec<u8>;
    }

    /// Checks seal/open and export on contexts built by `hpke` from each matching
    /// vector's key-schedule outputs.
    ///
    /// Encryptions are sealed (and the results opened) in list order on one context
    /// pair, then every export is checked on both contexts. Vectors for other cipher
    /// suites are skipped; if none match, this does nothing.
    ///
    /// NOTE(review): reusing one context pair assumes each vector's `encryptions`
    /// are listed in sequence-number order without gaps; confirm against the data.
    ///
    /// # Panics
    /// Panics on any error or mismatch.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn verify_hpke_context_tests<C: TestHpke>(hpke: &C, cipher_suite: CipherSuite) {
        for test_case in get_test_cases()
            .into_iter()
            .filter(|tc| matches!(tc.algo.cipher_suite(), Some(c) if c == cipher_suite))
        {
            let (mut context_s, mut context_r) = hpke.hpke_context(
                test_case.key,
                test_case.base_nonce,
                test_case.exporter_secret,
            );

            for enc_test_case in test_case.encryptions {
                // Encrypt
                let ct = context_s
                    .seal(Some(&enc_test_case.aad), &enc_test_case.plaintext)
                    .await
                    .unwrap();

                assert_eq!(ct, enc_test_case.ciphertext);

                // Decrypt
                let pt = context_r.open(Some(&enc_test_case.aad), &ct).await.unwrap();

                assert_eq!(pt, enc_test_case.plaintext);
            }

            for test in test_case.exports {
                let exported_s = context_s.export(&test.exporter_context, test.length).await;
                assert_eq!(exported_s.unwrap(), test.exported_value);

                let exported_r = context_r.export(&test.exporter_context, test.length).await;
                assert_eq!(exported_r.unwrap(), test.exported_value);
            }
        }
    }

    /// Checks KEM encapsulation and decapsulation for each matching vector.
    ///
    /// `encap(ikm_e, pk_rm)` must reproduce `enc` and `shared_secret`, and
    /// `decap(enc, sk_rm, pk_rm)` must reproduce `shared_secret`. Vectors for other
    /// cipher suites are skipped.
    ///
    /// # Panics
    /// Panics on any mismatch.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn verify_hpke_encap_tests<C: TestHpke>(hpke: &mut C, cipher_suite: CipherSuite) {
        for test_case in get_test_cases()
            .into_iter()
            .filter(|tc| matches!(tc.algo.cipher_suite(), Some(c) if c == cipher_suite))
        {
            let out = hpke.encap(test_case.ikm_e, test_case.pk_rm.clone());

            assert_eq!(&out.enc, &test_case.enc);
            assert_eq!(&out.shared_secret, &test_case.shared_secret);

            let shared_secret = hpke.decap(test_case.enc, test_case.sk_rm, test_case.pk_rm);

            assert_eq!(shared_secret, test_case.shared_secret);
        }
    }
}
688