1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec::Vec;
6 use core::fmt::{self, Debug};
7 use mls_rs_codec::{MlsEncode, MlsSize};
8 use mls_rs_core::{
9     crypto::{CipherSuiteProvider, HpkeCiphertext, HpkePublicKey, HpkeSecretKey},
10     error::IntoAnyError,
11 };
12 use zeroize::Zeroizing;
13 
14 use crate::client::MlsError;
15 
/// Labelled context that is MLS-encoded and handed to the HPKE seal / open calls in this file.
///
/// Both fields are encoded with `mls_rs_codec::byte_vec`. Field order is part of the encoded
/// form, so it must not be rearranged.
#[derive(Clone, MlsSize, MlsEncode)]
struct EncryptContext<'a> {
    /// Full label bytes; `EncryptContext::new` builds this as `"MLS 1.0 "` followed by the
    /// caller-supplied label.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    label: Vec<u8>,
    /// Caller-supplied context bytes, borrowed for the lifetime of this value.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    context: &'a [u8],
}
23 
24 impl Debug for EncryptContext<'_> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result25     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26         f.debug_struct("EncryptContext")
27             .field("label", &mls_rs_core::debug::pretty_bytes(&self.label))
28             .field("context", &mls_rs_core::debug::pretty_bytes(self.context))
29             .finish()
30     }
31 }
32 
33 impl<'a> EncryptContext<'a> {
new(label: &str, context: &'a [u8]) -> Self34     pub fn new(label: &str, context: &'a [u8]) -> Self {
35         Self {
36             label: [b"MLS 1.0 ", label.as_bytes()].concat(),
37             context,
38         }
39     }
40 }
41 
42 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
43 #[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
44 #[cfg_attr(
45     all(not(target_arch = "wasm32"), mls_build_async),
46     maybe_async::must_be_async
47 )]
48 
49 pub(crate) trait HpkeEncryptable: Sized {
50     const ENCRYPT_LABEL: &'static str;
51 
encrypt<P: CipherSuiteProvider>( &self, cipher_suite_provider: &P, public_key: &HpkePublicKey, context: &[u8], ) -> Result<HpkeCiphertext, MlsError>52     async fn encrypt<P: CipherSuiteProvider>(
53         &self,
54         cipher_suite_provider: &P,
55         public_key: &HpkePublicKey,
56         context: &[u8],
57     ) -> Result<HpkeCiphertext, MlsError> {
58         let context = EncryptContext::new(Self::ENCRYPT_LABEL, context)
59             .mls_encode_to_vec()
60             .map(Zeroizing::new)?;
61 
62         let content = self.get_bytes().map(Zeroizing::new)?;
63 
64         cipher_suite_provider
65             .hpke_seal(public_key, &context, None, &content)
66             .await
67             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
68     }
69 
decrypt<P: CipherSuiteProvider>( cipher_suite_provider: &P, secret_key: &HpkeSecretKey, public_key: &HpkePublicKey, context: &[u8], ciphertext: &HpkeCiphertext, ) -> Result<Self, MlsError>70     async fn decrypt<P: CipherSuiteProvider>(
71         cipher_suite_provider: &P,
72         secret_key: &HpkeSecretKey,
73         public_key: &HpkePublicKey,
74         context: &[u8],
75         ciphertext: &HpkeCiphertext,
76     ) -> Result<Self, MlsError> {
77         let context = EncryptContext::new(Self::ENCRYPT_LABEL, context).mls_encode_to_vec()?;
78 
79         let plaintext = cipher_suite_provider
80             .hpke_open(ciphertext, secret_key, public_key, &context, None)
81             .await
82             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
83 
84         Self::from_bytes(plaintext.to_vec())
85     }
86 
from_bytes(bytes: Vec<u8>) -> Result<Self, MlsError>87     fn from_bytes(bytes: Vec<u8>) -> Result<Self, MlsError>;
get_bytes(&self) -> Result<Vec<u8>, MlsError>88     fn get_bytes(&self) -> Result<Vec<u8>, MlsError>;
89 }
90 
91 #[cfg(test)]
92 pub(crate) mod test_utils {
93     use alloc::{string::String, vec::Vec};
94     use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
95     use mls_rs_core::crypto::{CipherSuiteProvider, HpkeCiphertext};
96 
97     use crate::{client::MlsError, crypto::test_utils::try_test_cipher_suite_provider};
98 
99     use super::HpkeEncryptable;
100 
    /// One `encrypt_with_label` test vector; every byte field is (de)serialized as a hex string.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct HpkeInteropTestCase {
        /// Secret key bytes (serialized under the name `priv`).
        #[serde(with = "hex::serde", rename = "priv")]
        secret: Vec<u8>,
        /// Public key bytes (serialized under the name `pub`).
        #[serde(with = "hex::serde", rename = "pub")]
        public: Vec<u8>,
        /// Label recorded in the vector. `verify` does not read it; the label used there is
        /// `TestEncryptable::ENCRYPT_LABEL`.
        label: String,
        /// Context bytes passed to `decrypt`.
        #[serde(with = "hex::serde")]
        context: Vec<u8>,
        /// Expected plaintext after decryption.
        #[serde(with = "hex::serde")]
        plaintext: Vec<u8>,
        /// `kem_output` part of the HPKE ciphertext.
        #[serde(with = "hex::serde")]
        kem_output: Vec<u8>,
        /// `ciphertext` part of the HPKE ciphertext.
        #[serde(with = "hex::serde")]
        ciphertext: Vec<u8>,
    }
117 
    /// A test-vector entry: a cipher suite identifier plus its `encrypt_with_label` case.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct InteropTestCase {
        /// Cipher suite identifier used to look up a test provider.
        cipher_suite: u16,
        /// The HPKE test vector to verify with that cipher suite.
        encrypt_with_label: HpkeInteropTestCase,
    }
123 
124     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_basic_crypto_test_vectors()125     async fn test_basic_crypto_test_vectors() {
126         // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/crypto-basics.json
127         let test_cases: Vec<InteropTestCase> =
128             load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());
129 
130         for test_case in test_cases {
131             if let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) {
132                 test_case.encrypt_with_label.verify(&cs).await
133             }
134         }
135     }
136 
    /// Minimal byte-vector newtype used to exercise `HpkeEncryptable` in the tests below.
    #[derive(Clone, Debug, MlsSize, MlsEncode, MlsDecode)]
    struct TestEncryptable(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>);
139 
140     impl HpkeEncryptable for TestEncryptable {
141         const ENCRYPT_LABEL: &'static str = "EncryptWithLabel";
142 
from_bytes(bytes: Vec<u8>) -> Result<Self, MlsError>143         fn from_bytes(bytes: Vec<u8>) -> Result<Self, MlsError> {
144             Ok(Self(bytes))
145         }
146 
147         #[cfg_attr(coverage_nightly, coverage(off))]
get_bytes(&self) -> Result<Vec<u8>, MlsError>148         fn get_bytes(&self) -> Result<Vec<u8>, MlsError> {
149             Ok(self.0.clone())
150         }
151     }
152 
153     impl HpkeInteropTestCase {
154         #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
verify<P: CipherSuiteProvider>(&self, cs: &P)155         pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) {
156             let secret = self.secret.clone().into();
157             let public = self.public.clone().into();
158 
159             let ciphertext = HpkeCiphertext {
160                 kem_output: self.kem_output.clone(),
161                 ciphertext: self.ciphertext.clone(),
162             };
163 
164             let computed_plaintext =
165                 TestEncryptable::decrypt(cs, &secret, &public, &self.context, &ciphertext)
166                     .await
167                     .unwrap();
168 
169             assert_eq!(&computed_plaintext.0, &self.plaintext)
170         }
171     }
172 }
173