1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use crate::client::MlsError;
6 use crate::extension::ExternalPubExt;
7 use crate::group::{GroupContext, MembershipTag};
8 use crate::psk::secret::PskSecret;
9 #[cfg(feature = "psk")]
10 use crate::psk::PreSharedKey;
11 use crate::tree_kem::path_secret::PathSecret;
12 use crate::CipherSuiteProvider;
13 
14 #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
15 use crate::group::SecretTree;
16 
17 use alloc::vec;
18 use alloc::vec::Vec;
19 use core::fmt::{self, Debug};
20 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
21 use mls_rs_core::error::IntoAnyError;
22 use zeroize::Zeroizing;
23 
24 use crate::crypto::{HpkeContextR, HpkeContextS, HpkePublicKey, HpkeSecretKey};
25 
26 use super::epoch::{EpochSecrets, SenderDataSecret};
27 use super::message_signature::AuthenticatedContent;
28 
29 #[derive(Clone, PartialEq, Eq, Default, MlsEncode, MlsDecode, MlsSize)]
30 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
31 pub struct KeySchedule {
32     #[mls_codec(with = "mls_rs_codec::byte_vec")]
33     #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
34     exporter_secret: Zeroizing<Vec<u8>>,
35     #[mls_codec(with = "mls_rs_codec::byte_vec")]
36     #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
37     pub authentication_secret: Zeroizing<Vec<u8>>,
38     #[mls_codec(with = "mls_rs_codec::byte_vec")]
39     #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
40     external_secret: Zeroizing<Vec<u8>>,
41     #[mls_codec(with = "mls_rs_codec::byte_vec")]
42     #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
43     membership_key: Zeroizing<Vec<u8>>,
44     init_secret: InitSecret,
45 }
46 
47 impl Debug for KeySchedule {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result48     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49         f.debug_struct("KeySchedule")
50             .field(
51                 "exporter_secret",
52                 &mls_rs_core::debug::pretty_bytes(&self.exporter_secret),
53             )
54             .field(
55                 "authentication_secret",
56                 &mls_rs_core::debug::pretty_bytes(&self.authentication_secret),
57             )
58             .field(
59                 "external_secret",
60                 &mls_rs_core::debug::pretty_bytes(&self.external_secret),
61             )
62             .field(
63                 "membership_key",
64                 &mls_rs_core::debug::pretty_bytes(&self.membership_key),
65             )
66             .field("init_secret", &self.init_secret)
67             .finish()
68     }
69 }
70 
71 pub(crate) struct KeyScheduleDerivationResult {
72     pub(crate) key_schedule: KeySchedule,
73     pub(crate) confirmation_key: Zeroizing<Vec<u8>>,
74     pub(crate) joiner_secret: JoinerSecret,
75     pub(crate) epoch_secrets: EpochSecrets,
76 }
77 
78 impl KeySchedule {
new(init_secret: InitSecret) -> Self79     pub fn new(init_secret: InitSecret) -> Self {
80         KeySchedule {
81             init_secret,
82             ..Default::default()
83         }
84     }
85 
86     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
derive_for_external<P: CipherSuiteProvider>( &self, kem_output: &[u8], cipher_suite: &P, ) -> Result<KeySchedule, MlsError>87     pub async fn derive_for_external<P: CipherSuiteProvider>(
88         &self,
89         kem_output: &[u8],
90         cipher_suite: &P,
91     ) -> Result<KeySchedule, MlsError> {
92         let (secret, public) = self.get_external_key_pair(cipher_suite).await?;
93 
94         let init_secret =
95             InitSecret::decode_for_external(cipher_suite, kem_output, &secret, &public).await?;
96 
97         Ok(KeySchedule::new(init_secret))
98     }
99 
100     /// Returns the derived epoch as well as the joiner secret required for building welcome
101     /// messages
102     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
from_key_schedule<P: CipherSuiteProvider>( last_key_schedule: &KeySchedule, commit_secret: &PathSecret, context: &GroupContext, #[cfg(any(feature = "secret_tree_access", feature = "private_message"))] secret_tree_size: u32, psk_secret: &PskSecret, cipher_suite_provider: &P, ) -> Result<KeyScheduleDerivationResult, MlsError>103     pub(crate) async fn from_key_schedule<P: CipherSuiteProvider>(
104         last_key_schedule: &KeySchedule,
105         commit_secret: &PathSecret,
106         context: &GroupContext,
107         #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
108         secret_tree_size: u32,
109         psk_secret: &PskSecret,
110         cipher_suite_provider: &P,
111     ) -> Result<KeyScheduleDerivationResult, MlsError> {
112         let joiner_seed = cipher_suite_provider
113             .kdf_extract(&last_key_schedule.init_secret.0, commit_secret)
114             .await
115             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
116 
117         let joiner_secret = kdf_expand_with_label(
118             cipher_suite_provider,
119             &joiner_seed,
120             b"joiner",
121             &context.mls_encode_to_vec()?,
122             None,
123         )
124         .await?
125         .into();
126 
127         let key_schedule_result = Self::from_joiner(
128             cipher_suite_provider,
129             &joiner_secret,
130             context,
131             #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
132             secret_tree_size,
133             psk_secret,
134         )
135         .await?;
136 
137         Ok(KeyScheduleDerivationResult {
138             key_schedule: key_schedule_result.key_schedule,
139             confirmation_key: key_schedule_result.confirmation_key,
140             joiner_secret,
141             epoch_secrets: key_schedule_result.epoch_secrets,
142         })
143     }
144 
145     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
from_joiner<P: CipherSuiteProvider>( cipher_suite_provider: &P, joiner_secret: &JoinerSecret, context: &GroupContext, #[cfg(any(feature = "secret_tree_access", feature = "private_message"))] secret_tree_size: u32, psk_secret: &PskSecret, ) -> Result<KeyScheduleDerivationResult, MlsError>146     pub(crate) async fn from_joiner<P: CipherSuiteProvider>(
147         cipher_suite_provider: &P,
148         joiner_secret: &JoinerSecret,
149         context: &GroupContext,
150         #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
151         secret_tree_size: u32,
152         psk_secret: &PskSecret,
153     ) -> Result<KeyScheduleDerivationResult, MlsError> {
154         let epoch_seed =
155             get_pre_epoch_secret(cipher_suite_provider, psk_secret, joiner_secret).await?;
156         let context = context.mls_encode_to_vec()?;
157 
158         let epoch_secret =
159             kdf_expand_with_label(cipher_suite_provider, &epoch_seed, b"epoch", &context, None)
160                 .await?;
161 
162         Self::from_epoch_secret(
163             cipher_suite_provider,
164             &epoch_secret,
165             #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
166             secret_tree_size,
167         )
168         .await
169     }
170 
171     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
from_random_epoch_secret<P: CipherSuiteProvider>( cipher_suite_provider: &P, #[cfg(any(feature = "secret_tree_access", feature = "private_message"))] secret_tree_size: u32, ) -> Result<KeyScheduleDerivationResult, MlsError>172     pub(crate) async fn from_random_epoch_secret<P: CipherSuiteProvider>(
173         cipher_suite_provider: &P,
174         #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
175         secret_tree_size: u32,
176     ) -> Result<KeyScheduleDerivationResult, MlsError> {
177         let epoch_secret = cipher_suite_provider
178             .random_bytes_vec(cipher_suite_provider.kdf_extract_size())
179             .map(Zeroizing::new)
180             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
181 
182         Self::from_epoch_secret(
183             cipher_suite_provider,
184             &epoch_secret,
185             #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
186             secret_tree_size,
187         )
188         .await
189     }
190 
191     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
from_epoch_secret<P: CipherSuiteProvider>( cipher_suite_provider: &P, epoch_secret: &[u8], #[cfg(any(feature = "secret_tree_access", feature = "private_message"))] secret_tree_size: u32, ) -> Result<KeyScheduleDerivationResult, MlsError>192     async fn from_epoch_secret<P: CipherSuiteProvider>(
193         cipher_suite_provider: &P,
194         epoch_secret: &[u8],
195         #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
196         secret_tree_size: u32,
197     ) -> Result<KeyScheduleDerivationResult, MlsError> {
198         let secrets_producer = SecretsProducer::new(cipher_suite_provider, epoch_secret);
199 
200         let epoch_secrets = EpochSecrets {
201             #[cfg(feature = "psk")]
202             resumption_secret: PreSharedKey::from(secrets_producer.derive(b"resumption").await?),
203             sender_data_secret: SenderDataSecret::from(
204                 secrets_producer.derive(b"sender data").await?,
205             ),
206             #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
207             secret_tree: SecretTree::new(
208                 secret_tree_size,
209                 secrets_producer.derive(b"encryption").await?,
210             ),
211         };
212 
213         let key_schedule = Self {
214             exporter_secret: secrets_producer.derive(b"exporter").await?,
215             authentication_secret: secrets_producer.derive(b"authentication").await?,
216             external_secret: secrets_producer.derive(b"external").await?,
217             membership_key: secrets_producer.derive(b"membership").await?,
218             init_secret: InitSecret(secrets_producer.derive(b"init").await?),
219         };
220 
221         Ok(KeyScheduleDerivationResult {
222             key_schedule,
223             confirmation_key: secrets_producer.derive(b"confirm").await?,
224             joiner_secret: Zeroizing::new(vec![]).into(),
225             epoch_secrets,
226         })
227     }
228 
229     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
export_secret<P: CipherSuiteProvider>( &self, label: &[u8], context: &[u8], len: usize, cipher_suite: &P, ) -> Result<Zeroizing<Vec<u8>>, MlsError>230     pub async fn export_secret<P: CipherSuiteProvider>(
231         &self,
232         label: &[u8],
233         context: &[u8],
234         len: usize,
235         cipher_suite: &P,
236     ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
237         let secret = kdf_derive_secret(cipher_suite, &self.exporter_secret, label).await?;
238 
239         let context_hash = cipher_suite
240             .hash(context)
241             .await
242             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
243 
244         kdf_expand_with_label(cipher_suite, &secret, b"exported", &context_hash, Some(len)).await
245     }
246 
247     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_membership_tag<P: CipherSuiteProvider>( &self, content: &AuthenticatedContent, context: &GroupContext, cipher_suite_provider: &P, ) -> Result<MembershipTag, MlsError>248     pub async fn get_membership_tag<P: CipherSuiteProvider>(
249         &self,
250         content: &AuthenticatedContent,
251         context: &GroupContext,
252         cipher_suite_provider: &P,
253     ) -> Result<MembershipTag, MlsError> {
254         MembershipTag::create(
255             content,
256             context,
257             &self.membership_key,
258             cipher_suite_provider,
259         )
260         .await
261     }
262 
263     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_external_key_pair<P: CipherSuiteProvider>( &self, cipher_suite: &P, ) -> Result<(HpkeSecretKey, HpkePublicKey), MlsError>264     pub async fn get_external_key_pair<P: CipherSuiteProvider>(
265         &self,
266         cipher_suite: &P,
267     ) -> Result<(HpkeSecretKey, HpkePublicKey), MlsError> {
268         cipher_suite
269             .kem_derive(&self.external_secret)
270             .await
271             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
272     }
273 
274     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_external_key_pair_ext<P: CipherSuiteProvider>( &self, cipher_suite: &P, ) -> Result<ExternalPubExt, MlsError>275     pub async fn get_external_key_pair_ext<P: CipherSuiteProvider>(
276         &self,
277         cipher_suite: &P,
278     ) -> Result<ExternalPubExt, MlsError> {
279         let (_external_secret, external_pub) = self.get_external_key_pair(cipher_suite).await?;
280 
281         Ok(ExternalPubExt { external_pub })
282     }
283 }
284 
285 #[derive(MlsEncode, MlsSize)]
286 struct Label<'a> {
287     length: u16,
288     #[mls_codec(with = "mls_rs_codec::byte_vec")]
289     label: Vec<u8>,
290     #[mls_codec(with = "mls_rs_codec::byte_vec")]
291     context: &'a [u8],
292 }
293 
294 impl<'a> Label<'a> {
new(length: u16, label: &'a [u8], context: &'a [u8]) -> Self295     fn new(length: u16, label: &'a [u8], context: &'a [u8]) -> Self {
296         Self {
297             length,
298             label: [b"MLS 1.0 ", label].concat(),
299             context,
300         }
301     }
302 }
303 
304 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
kdf_expand_with_label<P: CipherSuiteProvider>( cipher_suite_provider: &P, secret: &[u8], label: &[u8], context: &[u8], len: Option<usize>, ) -> Result<Zeroizing<Vec<u8>>, MlsError>305 pub(crate) async fn kdf_expand_with_label<P: CipherSuiteProvider>(
306     cipher_suite_provider: &P,
307     secret: &[u8],
308     label: &[u8],
309     context: &[u8],
310     len: Option<usize>,
311 ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
312     let extract_size = cipher_suite_provider.kdf_extract_size();
313     let len = len.unwrap_or(extract_size);
314     let label = Label::new(len as u16, label, context);
315 
316     cipher_suite_provider
317         .kdf_expand(secret, &label.mls_encode_to_vec()?, len)
318         .await
319         .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
320 }
321 
322 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
kdf_derive_secret<P: CipherSuiteProvider>( cipher_suite_provider: &P, secret: &[u8], label: &[u8], ) -> Result<Zeroizing<Vec<u8>>, MlsError>323 pub(crate) async fn kdf_derive_secret<P: CipherSuiteProvider>(
324     cipher_suite_provider: &P,
325     secret: &[u8],
326     label: &[u8],
327 ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
328     kdf_expand_with_label(cipher_suite_provider, secret, label, &[], None).await
329 }
330 
331 #[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
332 pub(crate) struct JoinerSecret(#[mls_codec(with = "mls_rs_codec::byte_vec")] Zeroizing<Vec<u8>>);
333 
334 impl Debug for JoinerSecret {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result335     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
336         mls_rs_core::debug::pretty_bytes(&self.0)
337             .named("JoinerSecret")
338             .fmt(f)
339     }
340 }
341 
342 impl From<Zeroizing<Vec<u8>>> for JoinerSecret {
from(bytes: Zeroizing<Vec<u8>>) -> Self343     fn from(bytes: Zeroizing<Vec<u8>>) -> Self {
344         Self(bytes)
345     }
346 }
347 
348 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_pre_epoch_secret<P: CipherSuiteProvider>( cipher_suite_provider: &P, psk_secret: &PskSecret, joiner_secret: &JoinerSecret, ) -> Result<Zeroizing<Vec<u8>>, MlsError>349 pub(crate) async fn get_pre_epoch_secret<P: CipherSuiteProvider>(
350     cipher_suite_provider: &P,
351     psk_secret: &PskSecret,
352     joiner_secret: &JoinerSecret,
353 ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
354     cipher_suite_provider
355         .kdf_extract(&joiner_secret.0, psk_secret)
356         .await
357         .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
358 }
359 
360 struct SecretsProducer<'a, P: CipherSuiteProvider> {
361     cipher_suite_provider: &'a P,
362     epoch_secret: &'a [u8],
363 }
364 
365 impl<'a, P: CipherSuiteProvider> SecretsProducer<'a, P> {
new(cipher_suite_provider: &'a P, epoch_secret: &'a [u8]) -> Self366     fn new(cipher_suite_provider: &'a P, epoch_secret: &'a [u8]) -> Self {
367         Self {
368             cipher_suite_provider,
369             epoch_secret,
370         }
371     }
372 
373     // TODO document somewhere in the crypto provider that the RFC defines the length of all secrets as
374     // KDF extract size but then inputs secrets as MAC keys etc, therefore, we require that these
375     // lengths match in the crypto provider
376     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
derive(&self, label: &[u8]) -> Result<Zeroizing<Vec<u8>>, MlsError>377     async fn derive(&self, label: &[u8]) -> Result<Zeroizing<Vec<u8>>, MlsError> {
378         kdf_derive_secret(self.cipher_suite_provider, self.epoch_secret, label).await
379     }
380 }
381 
/// Exporter context passed to the HPKE `export` call on both the sending and the receiving
/// side when the init secret of an external commit is computed.
const EXPORTER_CONTEXT: &[u8] = b"MLS 1.0 external init secret";
383 
384 #[derive(Clone, Eq, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)]
385 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
386 pub struct InitSecret(
387     #[mls_codec(with = "mls_rs_codec::byte_vec")]
388     #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
389     Zeroizing<Vec<u8>>,
390 );
391 
392 impl Debug for InitSecret {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result393     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
394         mls_rs_core::debug::pretty_bytes(&self.0)
395             .named("InitSecret")
396             .fmt(f)
397     }
398 }
399 
400 impl InitSecret {
401     /// Returns init secret and KEM output to be used when creating an external commit.
402     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
encode_for_external<P: CipherSuiteProvider>( cipher_suite: &P, external_pub: &HpkePublicKey, ) -> Result<(Self, Vec<u8>), MlsError>403     pub async fn encode_for_external<P: CipherSuiteProvider>(
404         cipher_suite: &P,
405         external_pub: &HpkePublicKey,
406     ) -> Result<(Self, Vec<u8>), MlsError> {
407         let (kem_output, context) = cipher_suite
408             .hpke_setup_s(external_pub, &[])
409             .await
410             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
411 
412         let init_secret = context
413             .export(EXPORTER_CONTEXT, cipher_suite.kdf_extract_size())
414             .await
415             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
416 
417         Ok((InitSecret(Zeroizing::new(init_secret)), kem_output))
418     }
419 
420     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
decode_for_external<P: CipherSuiteProvider>( cipher_suite: &P, kem_output: &[u8], external_secret: &HpkeSecretKey, external_pub: &HpkePublicKey, ) -> Result<Self, MlsError>421     pub async fn decode_for_external<P: CipherSuiteProvider>(
422         cipher_suite: &P,
423         kem_output: &[u8],
424         external_secret: &HpkeSecretKey,
425         external_pub: &HpkePublicKey,
426     ) -> Result<Self, MlsError> {
427         let context = cipher_suite
428             .hpke_setup_r(kem_output, external_secret, external_pub, &[])
429             .await
430             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
431 
432         context
433             .export(EXPORTER_CONTEXT, cipher_suite.kdf_extract_size())
434             .await
435             .map(Zeroizing::new)
436             .map(InitSecret)
437             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
438     }
439 }
440 
441 pub(crate) struct WelcomeSecret<'a, P: CipherSuiteProvider> {
442     cipher_suite: &'a P,
443     key: Zeroizing<Vec<u8>>,
444     nonce: Zeroizing<Vec<u8>>,
445 }
446 
447 impl<'a, P: CipherSuiteProvider> WelcomeSecret<'a, P> {
448     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
from_joiner_secret( cipher_suite: &'a P, joiner_secret: &JoinerSecret, psk_secret: &PskSecret, ) -> Result<WelcomeSecret<'a, P>, MlsError>449     pub(crate) async fn from_joiner_secret(
450         cipher_suite: &'a P,
451         joiner_secret: &JoinerSecret,
452         psk_secret: &PskSecret,
453     ) -> Result<WelcomeSecret<'a, P>, MlsError> {
454         let welcome_secret = get_welcome_secret(cipher_suite, joiner_secret, psk_secret).await?;
455 
456         let key_len = cipher_suite.aead_key_size();
457         let key = kdf_expand_with_label(cipher_suite, &welcome_secret, b"key", &[], Some(key_len))
458             .await?;
459 
460         let nonce_len = cipher_suite.aead_nonce_size();
461 
462         let nonce = kdf_expand_with_label(
463             cipher_suite,
464             &welcome_secret,
465             b"nonce",
466             &[],
467             Some(nonce_len),
468         )
469         .await?;
470 
471         Ok(Self {
472             cipher_suite,
473             key,
474             nonce,
475         })
476     }
477 
478     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, MlsError>479     pub(crate) async fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, MlsError> {
480         self.cipher_suite
481             .aead_seal(&self.key, plaintext, None, &self.nonce)
482             .await
483             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
484     }
485 
486     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
decrypt(&self, ciphertext: &[u8]) -> Result<Zeroizing<Vec<u8>>, MlsError>487     pub(crate) async fn decrypt(&self, ciphertext: &[u8]) -> Result<Zeroizing<Vec<u8>>, MlsError> {
488         self.cipher_suite
489             .aead_open(&self.key, ciphertext, None, &self.nonce)
490             .await
491             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
492     }
493 }
494 
495 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_welcome_secret<P: CipherSuiteProvider>( cipher_suite: &P, joiner_secret: &JoinerSecret, psk_secret: &PskSecret, ) -> Result<Zeroizing<Vec<u8>>, MlsError>496 async fn get_welcome_secret<P: CipherSuiteProvider>(
497     cipher_suite: &P,
498     joiner_secret: &JoinerSecret,
499     psk_secret: &PskSecret,
500 ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
501     let epoch_seed = get_pre_epoch_secret(cipher_suite, psk_secret, joiner_secret).await?;
502     kdf_derive_secret(cipher_suite, &epoch_seed, b"welcome").await
503 }
504 
/// Test-only helpers for constructing and inspecting key schedule values.
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::vec;
    use alloc::vec::Vec;
    use mls_rs_core::crypto::CipherSuiteProvider;
    use zeroize::Zeroizing;

    use crate::{cipher_suite::CipherSuite, crypto::test_utils::test_cipher_suite_provider};

    use super::{InitSecret, JoinerSecret, KeySchedule};

    #[cfg(all(feature = "rfc_compliant", not(mls_build_async)))]
    use mls_rs_core::error::IntoAnyError;

    #[cfg(all(feature = "rfc_compliant", not(mls_build_async)))]
    use super::MlsError;

    impl From<JoinerSecret> for Vec<u8> {
        /// Moves the raw bytes out of the joiner secret, leaving an empty buffer behind.
        fn from(mut value: JoinerSecret) -> Self {
            let bytes: &mut Vec<u8> = &mut value.0;

            core::mem::take(bytes)
        }
    }

    /// Builds a key schedule with fixed contents: every byte secret is `kdf_extract_size()`
    /// bytes of `1`, and the init secret is the same number of `0` bytes.
    pub(crate) fn get_test_key_schedule(cipher_suite: CipherSuite) -> KeySchedule {
        let secret_len = test_cipher_suite_provider(cipher_suite).kdf_extract_size();
        let ones = || Zeroizing::new(vec![1u8; secret_len]);

        KeySchedule {
            exporter_secret: ones(),
            authentication_secret: ones(),
            external_secret: ones(),
            membership_key: ones(),
            init_secret: InitSecret::new(vec![0u8; secret_len]),
        }
    }

    impl InitSecret {
        /// Wraps the given bytes as an init secret.
        pub fn new(init_secret: Vec<u8>) -> Self {
            Self(Zeroizing::new(init_secret))
        }

        /// Generates an init secret of `kdf_extract_size()` random bytes.
        #[cfg(all(feature = "rfc_compliant", test, not(mls_build_async)))]
        #[cfg_attr(coverage_nightly, coverage(off))]
        pub fn random<P: CipherSuiteProvider>(cipher_suite: &P) -> Result<Self, MlsError> {
            let bytes = cipher_suite
                .random_bytes_vec(cipher_suite.kdf_extract_size())
                .map_err(|err| MlsError::CryptoProviderError(err.into_any_error()))?;

            Ok(InitSecret(Zeroizing::new(bytes)))
        }
    }

    #[cfg(feature = "rfc_compliant")]
    impl KeySchedule {
        /// Overwrites the membership key with `key`.
        pub fn set_membership_key(&mut self, key: Vec<u8>) {
            self.membership_key = Zeroizing::new(key);
        }
    }
}
564 
565 #[cfg(test)]
566 mod tests {
567     use crate::client::test_utils::TEST_PROTOCOL_VERSION;
568     use crate::crypto::test_utils::try_test_cipher_suite_provider;
569     use crate::group::key_schedule::{
570         get_welcome_secret, kdf_derive_secret, kdf_expand_with_label,
571     };
572     use crate::group::GroupContext;
573     use alloc::string::String;
574     use alloc::vec::Vec;
575     use mls_rs_codec::MlsEncode;
576     use mls_rs_core::crypto::CipherSuiteProvider;
577     use mls_rs_core::extension::ExtensionList;
578 
579     #[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
580     use crate::{
581         crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
582         group::{
583             key_schedule::KeyScheduleDerivationResult, test_utils::random_bytes, InitSecret,
584             PskSecret,
585         },
586     };
587 
588     #[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
589     use alloc::{string::ToString, vec};
590 
591     #[cfg(target_arch = "wasm32")]
592     use wasm_bindgen_test::wasm_bindgen_test as test;
593     use zeroize::Zeroizing;
594 
595     use super::test_utils::get_test_key_schedule;
596     use super::KeySchedule;
597 
598     #[derive(serde::Deserialize, serde::Serialize)]
599     struct TestCase {
600         cipher_suite: u16,
601         #[serde(with = "hex::serde")]
602         group_id: Vec<u8>,
603         #[serde(with = "hex::serde")]
604         initial_init_secret: Vec<u8>,
605         epochs: Vec<KeyScheduleEpoch>,
606     }
607 
608     #[derive(serde::Deserialize, serde::Serialize)]
609     struct KeyScheduleEpoch {
610         #[serde(with = "hex::serde")]
611         commit_secret: Vec<u8>,
612         #[serde(with = "hex::serde")]
613         psk_secret: Vec<u8>,
614         #[serde(with = "hex::serde")]
615         confirmed_transcript_hash: Vec<u8>,
616         #[serde(with = "hex::serde")]
617         tree_hash: Vec<u8>,
618 
619         #[serde(with = "hex::serde")]
620         group_context: Vec<u8>,
621 
622         #[serde(with = "hex::serde")]
623         joiner_secret: Vec<u8>,
624         #[serde(with = "hex::serde")]
625         welcome_secret: Vec<u8>,
626         #[serde(with = "hex::serde")]
627         init_secret: Vec<u8>,
628 
629         #[serde(with = "hex::serde")]
630         sender_data_secret: Vec<u8>,
631         #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
632         #[serde(with = "hex::serde")]
633         encryption_secret: Vec<u8>,
634         #[serde(with = "hex::serde")]
635         exporter_secret: Vec<u8>,
636         #[serde(with = "hex::serde")]
637         epoch_authenticator: Vec<u8>,
638         #[serde(with = "hex::serde")]
639         external_secret: Vec<u8>,
640         #[serde(with = "hex::serde")]
641         confirmation_key: Vec<u8>,
642         #[serde(with = "hex::serde")]
643         membership_key: Vec<u8>,
644         #[cfg(feature = "psk")]
645         #[serde(with = "hex::serde")]
646         resumption_psk: Vec<u8>,
647 
648         #[serde(with = "hex::serde")]
649         external_pub: Vec<u8>,
650 
651         exporter: KeyScheduleExporter,
652     }
653 
654     #[derive(serde::Deserialize, serde::Serialize)]
655     struct KeyScheduleExporter {
656         label: String,
657         #[serde(with = "hex::serde")]
658         context: Vec<u8>,
659         length: usize,
660         #[serde(with = "hex::serde")]
661         secret: Vec<u8>,
662     }
663 
664     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_key_schedule()665     async fn test_key_schedule() {
666         let test_cases: Vec<TestCase> =
667             load_test_case_json!(key_schedule_test_vector, generate_test_vector());
668 
669         for test_case in test_cases {
670             let Some(cs_provider) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
671                 continue;
672             };
673 
674             let mut key_schedule = get_test_key_schedule(cs_provider.cipher_suite());
675             key_schedule.init_secret.0 = Zeroizing::new(test_case.initial_init_secret);
676 
677             for (i, epoch) in test_case.epochs.into_iter().enumerate() {
678                 let context = GroupContext {
679                     protocol_version: TEST_PROTOCOL_VERSION,
680                     cipher_suite: cs_provider.cipher_suite(),
681                     group_id: test_case.group_id.clone(),
682                     epoch: i as u64,
683                     tree_hash: epoch.tree_hash,
684                     confirmed_transcript_hash: epoch.confirmed_transcript_hash.into(),
685                     extensions: ExtensionList::new(),
686                 };
687 
688                 assert_eq!(context.mls_encode_to_vec().unwrap(), epoch.group_context);
689 
690                 let psk = epoch.psk_secret.into();
691                 let commit = epoch.commit_secret.into();
692 
693                 let key_schedule_res = KeySchedule::from_key_schedule(
694                     &key_schedule,
695                     &commit,
696                     &context,
697                     #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
698                     32,
699                     &psk,
700                     &cs_provider,
701                 )
702                 .await
703                 .unwrap();
704 
705                 key_schedule = key_schedule_res.key_schedule;
706 
707                 let welcome =
708                     get_welcome_secret(&cs_provider, &key_schedule_res.joiner_secret, &psk)
709                         .await
710                         .unwrap();
711 
712                 assert_eq!(*welcome, epoch.welcome_secret);
713 
714                 let expected: Vec<u8> = key_schedule_res.joiner_secret.into();
715                 assert_eq!(epoch.joiner_secret, expected);
716 
717                 assert_eq!(&key_schedule.init_secret.0.to_vec(), &epoch.init_secret);
718 
719                 assert_eq!(
720                     epoch.sender_data_secret,
721                     *key_schedule_res.epoch_secrets.sender_data_secret.to_vec()
722                 );
723 
724                 #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
725                 assert_eq!(
726                     epoch.encryption_secret,
727                     *key_schedule_res.epoch_secrets.secret_tree.get_root_secret()
728                 );
729 
730                 assert_eq!(epoch.exporter_secret, key_schedule.exporter_secret.to_vec());
731 
732                 assert_eq!(
733                     epoch.epoch_authenticator,
734                     key_schedule.authentication_secret.to_vec()
735                 );
736 
737                 assert_eq!(epoch.external_secret, key_schedule.external_secret.to_vec());
738 
739                 assert_eq!(
740                     epoch.confirmation_key,
741                     key_schedule_res.confirmation_key.to_vec()
742                 );
743 
744                 assert_eq!(epoch.membership_key, key_schedule.membership_key.to_vec());
745 
746                 #[cfg(feature = "psk")]
747                 {
748                     let expected: Vec<u8> =
749                         key_schedule_res.epoch_secrets.resumption_secret.to_vec();
750 
751                     assert_eq!(epoch.resumption_psk, expected);
752                 }
753 
754                 let (_external_sec, external_pub) = key_schedule
755                     .get_external_key_pair(&cs_provider)
756                     .await
757                     .unwrap();
758 
759                 assert_eq!(epoch.external_pub, *external_pub);
760 
761                 let exp = epoch.exporter;
762 
763                 let exported = key_schedule
764                     .export_secret(exp.label.as_bytes(), &exp.context, exp.length, &cs_provider)
765                     .await
766                     .unwrap();
767 
768                 assert_eq!(exported.to_vec(), exp.secret);
769             }
770         }
771     }
772 
/// Generates a fresh key schedule test vector.
///
/// For every cipher suite reported by `TestCryptoProvider`, a random initial
/// init secret is installed into a test key schedule and two consecutive
/// epochs are derived from random commit secrets, tree hashes and confirmed
/// transcript hashes. Each suite contributes one `TestCase` holding both
/// epochs.
#[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
#[cfg_attr(coverage_nightly, coverage(off))]
fn generate_test_vector() -> Vec<TestCase> {
    let mut generated = Vec::new();

    for suite in TestCryptoProvider::all_supported_cipher_suites() {
        let provider = test_cipher_suite_provider(suite);
        let secret_len = provider.kdf_extract_size();

        // Group context for epoch 0; both hashes are random placeholders.
        let mut context = GroupContext {
            protocol_version: TEST_PROTOCOL_VERSION,
            cipher_suite: provider.cipher_suite(),
            group_id: b"my group 5".to_vec(),
            epoch: 0,
            tree_hash: random_bytes(secret_len),
            confirmed_transcript_hash: random_bytes(secret_len).into(),
            extensions: Default::default(),
        };

        // Start from the stock test schedule, swapping in a random init secret
        // that is also recorded in the resulting test case.
        let starting_init = InitSecret::random(&provider).unwrap();
        let mut schedule = get_test_key_schedule(provider.cipher_suite());
        schedule.init_secret = starting_init.clone();

        let first_commit = random_bytes(secret_len).into();
        let first_psk = PskSecret::new(&provider);

        let first_derivation = KeySchedule::from_key_schedule(
            &schedule,
            &first_commit,
            &context,
            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
            32,
            &first_psk,
            &provider,
        )
        .unwrap();

        // The second epoch chains off the schedule produced by the first one.
        // Clone it here because the derivation result is consumed just below.
        schedule = first_derivation.key_schedule.clone();

        let first_epoch = KeyScheduleEpoch::new(
            first_derivation,
            first_psk,
            first_commit.to_vec(),
            &context,
            &provider,
        );

        // Advance the group context to the next epoch with new random hashes.
        context.epoch += 1;
        context.confirmed_transcript_hash = random_bytes(secret_len).into();
        context.tree_hash = random_bytes(secret_len);

        let second_commit = random_bytes(secret_len).into();
        let second_psk = PskSecret::new(&provider);

        let second_derivation = KeySchedule::from_key_schedule(
            &schedule,
            &second_commit,
            &context,
            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
            32,
            &second_psk,
            &provider,
        )
        .unwrap();

        let second_epoch = KeyScheduleEpoch::new(
            second_derivation,
            second_psk,
            second_commit.to_vec(),
            &context,
            &provider,
        );

        generated.push(TestCase {
            cipher_suite: provider.cipher_suite().into(),
            group_id: context.group_id.clone(),
            initial_init_secret: starting_init.0.to_vec(),
            epochs: vec![first_epoch, second_epoch],
        });
    }

    generated
}
858 
859     #[cfg(not(all(not(mls_build_async), feature = "rfc_compliant")))]
generate_test_vector() -> Vec<TestCase>860     fn generate_test_vector() -> Vec<TestCase> {
861         panic!("Tests cannot be generated in async mode");
862     }
863 
#[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
impl KeyScheduleEpoch {
    /// Records one epoch of a generated test vector.
    ///
    /// Consumes the derivation result and copies its secrets into plain byte
    /// vectors, together with the inputs that produced them (`psk_secret`,
    /// `commit_secret` and the encoded `group_context`). It also derives the
    /// external public key, the welcome secret and a 64 byte exported secret
    /// for a fixed label and context, so the verifying test can recompute
    /// them.
    ///
    /// Panics if any derivation or the group context encoding fails.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn new<P: CipherSuiteProvider>(
        key_schedule_res: KeyScheduleDerivationResult,
        psk_secret: PskSecret,
        commit_secret: Vec<u8>,
        group_context: &GroupContext,
        cs: &P,
    ) -> Self {
        let KeyScheduleDerivationResult {
            key_schedule,
            joiner_secret,
            epoch_secrets,
            confirmation_key,
            ..
        } = key_schedule_res;

        let (_external_secret_key, external_pub) =
            key_schedule.get_external_key_pair(cs).unwrap();

        // Fixed exporter inputs; the exported secret is stored next to them.
        let label = "exporter label 15".to_string();
        let context = b"exporter context".to_vec();
        let length = 64;

        let secret = key_schedule
            .export_secret(label.as_bytes(), &context, length, cs)
            .unwrap()
            .to_vec();

        let exporter = KeyScheduleExporter {
            label,
            context,
            length,
            secret,
        };

        let welcome_secret = get_welcome_secret(cs, &joiner_secret, &psk_secret)
            .unwrap()
            .to_vec();

        Self {
            commit_secret,
            welcome_secret,
            psk_secret: psk_secret.to_vec(),
            group_context: group_context.mls_encode_to_vec().unwrap(),
            joiner_secret: joiner_secret.into(),
            init_secret: key_schedule.init_secret.0.to_vec(),
            sender_data_secret: epoch_secrets.sender_data_secret.to_vec(),
            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
            encryption_secret: epoch_secrets.secret_tree.get_root_secret(),
            exporter_secret: key_schedule.exporter_secret.to_vec(),
            epoch_authenticator: key_schedule.authentication_secret.to_vec(),
            external_secret: key_schedule.external_secret.to_vec(),
            confirmation_key: confirmation_key.to_vec(),
            membership_key: key_schedule.membership_key.to_vec(),
            #[cfg(feature = "psk")]
            resumption_psk: epoch_secrets.resumption_secret.to_vec(),
            external_pub: external_pub.to_vec(),
            exporter,
            confirmed_transcript_hash: group_context.confirmed_transcript_hash.to_vec(),
            tree_hash: group_context.tree_hash.clone(),
        }
    }
}
926 
927     #[derive(Debug, serde::Serialize, serde::Deserialize)]
928     struct ExpandWithLabelTestCase {
929         #[serde(with = "hex::serde")]
930         secret: Vec<u8>,
931         label: String,
932         #[serde(with = "hex::serde")]
933         context: Vec<u8>,
934         length: usize,
935         #[serde(with = "hex::serde")]
936         out: Vec<u8>,
937     }
938 
939     #[derive(Debug, serde::Serialize, serde::Deserialize)]
940     struct DeriveSecretTestCase {
941         #[serde(with = "hex::serde")]
942         secret: Vec<u8>,
943         label: String,
944         #[serde(with = "hex::serde")]
945         out: Vec<u8>,
946     }
947 
948     #[derive(Debug, serde::Serialize, serde::Deserialize)]
949     pub struct InteropTestCase {
950         cipher_suite: u16,
951         expand_with_label: ExpandWithLabelTestCase,
952         derive_secret: DeriveSecretTestCase,
953     }
954 
955     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_basic_crypto_test_vectors()956     async fn test_basic_crypto_test_vectors() {
957         // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/crypto-basics.json
958         let test_cases: Vec<InteropTestCase> =
959             load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());
960 
961         for test_case in test_cases {
962             if let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) {
963                 let test_exp = &test_case.expand_with_label;
964 
965                 let computed = kdf_expand_with_label(
966                     &cs,
967                     &test_exp.secret,
968                     test_exp.label.as_bytes(),
969                     &test_exp.context,
970                     Some(test_exp.length),
971                 )
972                 .await
973                 .unwrap();
974 
975                 assert_eq!(&computed.to_vec(), &test_exp.out);
976 
977                 let test_derive = &test_case.derive_secret;
978 
979                 let computed =
980                     kdf_derive_secret(&cs, &test_derive.secret, test_derive.label.as_bytes())
981                         .await
982                         .unwrap();
983 
984                 assert_eq!(&computed.to_vec(), &test_derive.out);
985             }
986         }
987     }
988 }
989