1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5 use crate::client::MlsError;
6 use crate::extension::ExternalPubExt;
7 use crate::group::{GroupContext, MembershipTag};
8 use crate::psk::secret::PskSecret;
9 #[cfg(feature = "psk")]
10 use crate::psk::PreSharedKey;
11 use crate::tree_kem::path_secret::PathSecret;
12 use crate::CipherSuiteProvider;
13
14 #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
15 use crate::group::SecretTree;
16
17 use alloc::vec;
18 use alloc::vec::Vec;
19 use core::fmt::{self, Debug};
20 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
21 use mls_rs_core::error::IntoAnyError;
22 use zeroize::Zeroizing;
23
24 use crate::crypto::{HpkeContextR, HpkeContextS, HpkePublicKey, HpkeSecretKey};
25
26 use super::epoch::{EpochSecrets, SenderDataSecret};
27 use super::message_signature::AuthenticatedContent;
28
/// Secrets retained for the lifetime of an epoch after its key schedule has run.
///
/// Every secret is held in a [`Zeroizing`] buffer so it is wiped when dropped.
///
/// The derived MLS codec and serde implementations process these fields in declaration
/// order, so reordering them changes the encoded form.
///
/// NOTE(review): the derived `PartialEq`/`Eq` compare secret bytes with ordinary (not
/// constant-time) slice equality; avoid relying on it where timing matters.
#[derive(Clone, PartialEq, Eq, Default, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct KeySchedule {
    /// Derived with the `"exporter"` label; input to [`KeySchedule::export_secret`].
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
    exporter_secret: Zeroizing<Vec<u8>>,
    /// Derived with the `"authentication"` label (the epoch authenticator).
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
    pub authentication_secret: Zeroizing<Vec<u8>>,
    /// Derived with the `"external"` label; the external HPKE key pair is derived from it.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
    external_secret: Zeroizing<Vec<u8>>,
    /// Derived with the `"membership"` label; keys [`KeySchedule::get_membership_tag`].
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
    membership_key: Zeroizing<Vec<u8>>,
    /// Derived with the `"init"` label; seeds the next epoch's key schedule.
    init_secret: InitSecret,
}
46
47 impl Debug for KeySchedule {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result48 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49 f.debug_struct("KeySchedule")
50 .field(
51 "exporter_secret",
52 &mls_rs_core::debug::pretty_bytes(&self.exporter_secret),
53 )
54 .field(
55 "authentication_secret",
56 &mls_rs_core::debug::pretty_bytes(&self.authentication_secret),
57 )
58 .field(
59 "external_secret",
60 &mls_rs_core::debug::pretty_bytes(&self.external_secret),
61 )
62 .field(
63 "membership_key",
64 &mls_rs_core::debug::pretty_bytes(&self.membership_key),
65 )
66 .field("init_secret", &self.init_secret)
67 .finish()
68 }
69 }
70
71 pub(crate) struct KeyScheduleDerivationResult {
72 pub(crate) key_schedule: KeySchedule,
73 pub(crate) confirmation_key: Zeroizing<Vec<u8>>,
74 pub(crate) joiner_secret: JoinerSecret,
75 pub(crate) epoch_secrets: EpochSecrets,
76 }
77
impl KeySchedule {
    /// Creates a key schedule holding only `init_secret`; every other secret is left at
    /// its `Default` (empty) value.
    pub fn new(init_secret: InitSecret) -> Self {
        KeySchedule {
            init_secret,
            ..Default::default()
        }
    }

    /// Builds the key schedule implied by an external-init `kem_output`.
    ///
    /// The HPKE key pair is derived from this schedule's `external_secret`. That pair is
    /// used to decapsulate `kem_output` via [`InitSecret::decode_for_external`], and the
    /// recovered init secret seeds a fresh [`KeySchedule`].
    ///
    /// # Errors
    /// Propagates crypto-provider failures from key derivation or HPKE setup/export.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn derive_for_external<P: CipherSuiteProvider>(
        &self,
        kem_output: &[u8],
        cipher_suite: &P,
    ) -> Result<KeySchedule, MlsError> {
        let (secret, public) = self.get_external_key_pair(cipher_suite).await?;

        let init_secret =
            InitSecret::decode_for_external(cipher_suite, kem_output, &secret, &public).await?;

        Ok(KeySchedule::new(init_secret))
    }

    /// Returns the derived epoch as well as the joiner secret required for building welcome
    /// messages
    ///
    /// The previous `init_secret` and `commit_secret` are passed to `kdf_extract`. The
    /// result is expanded with the `"joiner"` label over the MLS-encoded `context` (KDF
    /// extract size output) to form the joiner secret. The remainder of the schedule is
    /// delegated to [`KeySchedule::from_joiner`] (see RFC 9420, Section 8).
    ///
    /// `secret_tree_size` exists only with the `secret_tree_access` or `private_message`
    /// feature and is forwarded to the secret tree.
    ///
    /// # Errors
    /// Fails if `context` cannot be MLS-encoded or the crypto provider reports an error.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn from_key_schedule<P: CipherSuiteProvider>(
        last_key_schedule: &KeySchedule,
        commit_secret: &PathSecret,
        context: &GroupContext,
        #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
        secret_tree_size: u32,
        psk_secret: &PskSecret,
        cipher_suite_provider: &P,
    ) -> Result<KeyScheduleDerivationResult, MlsError> {
        let joiner_seed = cipher_suite_provider
            .kdf_extract(&last_key_schedule.init_secret.0, commit_secret)
            .await
            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;

        let joiner_secret = kdf_expand_with_label(
            cipher_suite_provider,
            &joiner_seed,
            b"joiner",
            &context.mls_encode_to_vec()?,
            None,
        )
        .await?
        .into();

        let key_schedule_result = Self::from_joiner(
            cipher_suite_provider,
            &joiner_secret,
            context,
            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
            secret_tree_size,
            psk_secret,
        )
        .await?;

        // `from_joiner` returns an empty joiner secret; substitute the real one so the
        // caller can build Welcome messages.
        Ok(KeyScheduleDerivationResult {
            key_schedule: key_schedule_result.key_schedule,
            confirmation_key: key_schedule_result.confirmation_key,
            joiner_secret,
            epoch_secrets: key_schedule_result.epoch_secrets,
        })
    }

    /// Runs the key schedule starting from a known joiner secret.
    ///
    /// The pre-epoch secret is extracted from `joiner_secret` and `psk_secret`. It is then
    /// expanded with the `"epoch"` label over the MLS-encoded `context`, and the resulting
    /// epoch secret feeds the per-epoch derivations.
    ///
    /// The returned `joiner_secret` field is empty.
    ///
    /// # Errors
    /// Fails if `context` cannot be MLS-encoded or the crypto provider reports an error.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn from_joiner<P: CipherSuiteProvider>(
        cipher_suite_provider: &P,
        joiner_secret: &JoinerSecret,
        context: &GroupContext,
        #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
        secret_tree_size: u32,
        psk_secret: &PskSecret,
    ) -> Result<KeyScheduleDerivationResult, MlsError> {
        let epoch_seed =
            get_pre_epoch_secret(cipher_suite_provider, psk_secret, joiner_secret).await?;
        let context = context.mls_encode_to_vec()?;

        let epoch_secret =
            kdf_expand_with_label(cipher_suite_provider, &epoch_seed, b"epoch", &context, None)
                .await?;

        Self::from_epoch_secret(
            cipher_suite_provider,
            &epoch_secret,
            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
            secret_tree_size,
        )
        .await
    }

    /// Runs the per-epoch derivations from a freshly generated random epoch secret of
    /// `kdf_extract_size()` bytes, skipping the joiner and PSK steps.
    ///
    /// The returned `joiner_secret` field is empty.
    ///
    /// # Errors
    /// Fails if the provider cannot generate random bytes or a derivation fails.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn from_random_epoch_secret<P: CipherSuiteProvider>(
        cipher_suite_provider: &P,
        #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
        secret_tree_size: u32,
    ) -> Result<KeyScheduleDerivationResult, MlsError> {
        let epoch_secret = cipher_suite_provider
            .random_bytes_vec(cipher_suite_provider.kdf_extract_size())
            .map(Zeroizing::new)
            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;

        Self::from_epoch_secret(
            cipher_suite_provider,
            &epoch_secret,
            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
            secret_tree_size,
        )
        .await
    }

    /// Derives every per-epoch secret from `epoch_secret`, one `DeriveSecret` per label.
    ///
    /// The labels are:
    /// - `"resumption"` (psk feature only)
    /// - `"sender data"`
    /// - `"encryption"` (secret tree root, feature-gated)
    /// - `"exporter"`, `"authentication"`, `"external"`, `"membership"` and `"init"`
    /// - `"confirm"`
    ///
    /// The returned `joiner_secret` field is empty.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn from_epoch_secret<P: CipherSuiteProvider>(
        cipher_suite_provider: &P,
        epoch_secret: &[u8],
        #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
        secret_tree_size: u32,
    ) -> Result<KeyScheduleDerivationResult, MlsError> {
        let secrets_producer = SecretsProducer::new(cipher_suite_provider, epoch_secret);

        let epoch_secrets = EpochSecrets {
            #[cfg(feature = "psk")]
            resumption_secret: PreSharedKey::from(secrets_producer.derive(b"resumption").await?),
            sender_data_secret: SenderDataSecret::from(
                secrets_producer.derive(b"sender data").await?,
            ),
            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
            secret_tree: SecretTree::new(
                secret_tree_size,
                secrets_producer.derive(b"encryption").await?,
            ),
        };

        let key_schedule = Self {
            exporter_secret: secrets_producer.derive(b"exporter").await?,
            authentication_secret: secrets_producer.derive(b"authentication").await?,
            external_secret: secrets_producer.derive(b"external").await?,
            membership_key: secrets_producer.derive(b"membership").await?,
            init_secret: InitSecret(secrets_producer.derive(b"init").await?),
        };

        Ok(KeyScheduleDerivationResult {
            key_schedule,
            confirmation_key: secrets_producer.derive(b"confirm").await?,
            // No joiner secret exists on this path; callers that have one fill it in.
            joiner_secret: Zeroizing::new(vec![]).into(),
            epoch_secrets,
        })
    }

    /// MLS exporter: `ExpandWithLabel(DeriveSecret(exporter_secret, label), "exported",
    /// Hash(context), len)`.
    ///
    /// NOTE(review): `len` reaches `kdf_expand_with_label`, which narrows it with
    /// `as u16`; lengths above `u16::MAX` are not representable. Confirm callers bound it.
    ///
    /// # Errors
    /// Propagates crypto-provider failures from hashing or KDF expansion.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn export_secret<P: CipherSuiteProvider>(
        &self,
        label: &[u8],
        context: &[u8],
        len: usize,
        cipher_suite: &P,
    ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
        let secret = kdf_derive_secret(cipher_suite, &self.exporter_secret, label).await?;

        let context_hash = cipher_suite
            .hash(context)
            .await
            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;

        kdf_expand_with_label(cipher_suite, &secret, b"exported", &context_hash, Some(len)).await
    }

    /// Computes the membership tag for `content` under this epoch's membership key via
    /// `MembershipTag::create`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn get_membership_tag<P: CipherSuiteProvider>(
        &self,
        content: &AuthenticatedContent,
        context: &GroupContext,
        cipher_suite_provider: &P,
    ) -> Result<MembershipTag, MlsError> {
        MembershipTag::create(
            content,
            context,
            &self.membership_key,
            cipher_suite_provider,
        )
        .await
    }

    /// Derives the external HPKE key pair from `external_secret` using the provider's
    /// `kem_derive`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn get_external_key_pair<P: CipherSuiteProvider>(
        &self,
        cipher_suite: &P,
    ) -> Result<(HpkeSecretKey, HpkePublicKey), MlsError> {
        cipher_suite
            .kem_derive(&self.external_secret)
            .await
            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
    }

    /// Returns the external public key wrapped in an [`ExternalPubExt`]; the derived
    /// private key is discarded.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn get_external_key_pair_ext<P: CipherSuiteProvider>(
        &self,
        cipher_suite: &P,
    ) -> Result<ExternalPubExt, MlsError> {
        let (_external_secret, external_pub) = self.get_external_key_pair(cipher_suite).await?;

        Ok(ExternalPubExt { external_pub })
    }
}
284
285 #[derive(MlsEncode, MlsSize)]
286 struct Label<'a> {
287 length: u16,
288 #[mls_codec(with = "mls_rs_codec::byte_vec")]
289 label: Vec<u8>,
290 #[mls_codec(with = "mls_rs_codec::byte_vec")]
291 context: &'a [u8],
292 }
293
294 impl<'a> Label<'a> {
new(length: u16, label: &'a [u8], context: &'a [u8]) -> Self295 fn new(length: u16, label: &'a [u8], context: &'a [u8]) -> Self {
296 Self {
297 length,
298 label: [b"MLS 1.0 ", label].concat(),
299 context,
300 }
301 }
302 }
303
304 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
kdf_expand_with_label<P: CipherSuiteProvider>( cipher_suite_provider: &P, secret: &[u8], label: &[u8], context: &[u8], len: Option<usize>, ) -> Result<Zeroizing<Vec<u8>>, MlsError>305 pub(crate) async fn kdf_expand_with_label<P: CipherSuiteProvider>(
306 cipher_suite_provider: &P,
307 secret: &[u8],
308 label: &[u8],
309 context: &[u8],
310 len: Option<usize>,
311 ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
312 let extract_size = cipher_suite_provider.kdf_extract_size();
313 let len = len.unwrap_or(extract_size);
314 let label = Label::new(len as u16, label, context);
315
316 cipher_suite_provider
317 .kdf_expand(secret, &label.mls_encode_to_vec()?, len)
318 .await
319 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
320 }
321
322 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
kdf_derive_secret<P: CipherSuiteProvider>( cipher_suite_provider: &P, secret: &[u8], label: &[u8], ) -> Result<Zeroizing<Vec<u8>>, MlsError>323 pub(crate) async fn kdf_derive_secret<P: CipherSuiteProvider>(
324 cipher_suite_provider: &P,
325 secret: &[u8],
326 label: &[u8],
327 ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
328 kdf_expand_with_label(cipher_suite_provider, secret, label, &[], None).await
329 }
330
331 #[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
332 pub(crate) struct JoinerSecret(#[mls_codec(with = "mls_rs_codec::byte_vec")] Zeroizing<Vec<u8>>);
333
334 impl Debug for JoinerSecret {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result335 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
336 mls_rs_core::debug::pretty_bytes(&self.0)
337 .named("JoinerSecret")
338 .fmt(f)
339 }
340 }
341
342 impl From<Zeroizing<Vec<u8>>> for JoinerSecret {
from(bytes: Zeroizing<Vec<u8>>) -> Self343 fn from(bytes: Zeroizing<Vec<u8>>) -> Self {
344 Self(bytes)
345 }
346 }
347
348 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_pre_epoch_secret<P: CipherSuiteProvider>( cipher_suite_provider: &P, psk_secret: &PskSecret, joiner_secret: &JoinerSecret, ) -> Result<Zeroizing<Vec<u8>>, MlsError>349 pub(crate) async fn get_pre_epoch_secret<P: CipherSuiteProvider>(
350 cipher_suite_provider: &P,
351 psk_secret: &PskSecret,
352 joiner_secret: &JoinerSecret,
353 ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
354 cipher_suite_provider
355 .kdf_extract(&joiner_secret.0, psk_secret)
356 .await
357 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
358 }
359
360 struct SecretsProducer<'a, P: CipherSuiteProvider> {
361 cipher_suite_provider: &'a P,
362 epoch_secret: &'a [u8],
363 }
364
365 impl<'a, P: CipherSuiteProvider> SecretsProducer<'a, P> {
new(cipher_suite_provider: &'a P, epoch_secret: &'a [u8]) -> Self366 fn new(cipher_suite_provider: &'a P, epoch_secret: &'a [u8]) -> Self {
367 Self {
368 cipher_suite_provider,
369 epoch_secret,
370 }
371 }
372
373 // TODO document somewhere in the crypto provider that the RFC defines the length of all secrets as
374 // KDF extract size but then inputs secrets as MAC keys etc, therefore, we require that these
375 // lengths match in the crypto provider
376 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
derive(&self, label: &[u8]) -> Result<Zeroizing<Vec<u8>>, MlsError>377 async fn derive(&self, label: &[u8]) -> Result<Zeroizing<Vec<u8>>, MlsError> {
378 kdf_derive_secret(self.cipher_suite_provider, self.epoch_secret, label).await
379 }
380 }
381
382 const EXPORTER_CONTEXT: &[u8] = b"MLS 1.0 external init secret";
383
384 #[derive(Clone, Eq, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)]
385 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
386 pub struct InitSecret(
387 #[mls_codec(with = "mls_rs_codec::byte_vec")]
388 #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
389 Zeroizing<Vec<u8>>,
390 );
391
392 impl Debug for InitSecret {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result393 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
394 mls_rs_core::debug::pretty_bytes(&self.0)
395 .named("InitSecret")
396 .fmt(f)
397 }
398 }
399
400 impl InitSecret {
401 /// Returns init secret and KEM output to be used when creating an external commit.
402 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
encode_for_external<P: CipherSuiteProvider>( cipher_suite: &P, external_pub: &HpkePublicKey, ) -> Result<(Self, Vec<u8>), MlsError>403 pub async fn encode_for_external<P: CipherSuiteProvider>(
404 cipher_suite: &P,
405 external_pub: &HpkePublicKey,
406 ) -> Result<(Self, Vec<u8>), MlsError> {
407 let (kem_output, context) = cipher_suite
408 .hpke_setup_s(external_pub, &[])
409 .await
410 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
411
412 let init_secret = context
413 .export(EXPORTER_CONTEXT, cipher_suite.kdf_extract_size())
414 .await
415 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
416
417 Ok((InitSecret(Zeroizing::new(init_secret)), kem_output))
418 }
419
420 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
decode_for_external<P: CipherSuiteProvider>( cipher_suite: &P, kem_output: &[u8], external_secret: &HpkeSecretKey, external_pub: &HpkePublicKey, ) -> Result<Self, MlsError>421 pub async fn decode_for_external<P: CipherSuiteProvider>(
422 cipher_suite: &P,
423 kem_output: &[u8],
424 external_secret: &HpkeSecretKey,
425 external_pub: &HpkePublicKey,
426 ) -> Result<Self, MlsError> {
427 let context = cipher_suite
428 .hpke_setup_r(kem_output, external_secret, external_pub, &[])
429 .await
430 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
431
432 context
433 .export(EXPORTER_CONTEXT, cipher_suite.kdf_extract_size())
434 .await
435 .map(Zeroizing::new)
436 .map(InitSecret)
437 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
438 }
439 }
440
441 pub(crate) struct WelcomeSecret<'a, P: CipherSuiteProvider> {
442 cipher_suite: &'a P,
443 key: Zeroizing<Vec<u8>>,
444 nonce: Zeroizing<Vec<u8>>,
445 }
446
447 impl<'a, P: CipherSuiteProvider> WelcomeSecret<'a, P> {
448 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
from_joiner_secret( cipher_suite: &'a P, joiner_secret: &JoinerSecret, psk_secret: &PskSecret, ) -> Result<WelcomeSecret<'a, P>, MlsError>449 pub(crate) async fn from_joiner_secret(
450 cipher_suite: &'a P,
451 joiner_secret: &JoinerSecret,
452 psk_secret: &PskSecret,
453 ) -> Result<WelcomeSecret<'a, P>, MlsError> {
454 let welcome_secret = get_welcome_secret(cipher_suite, joiner_secret, psk_secret).await?;
455
456 let key_len = cipher_suite.aead_key_size();
457 let key = kdf_expand_with_label(cipher_suite, &welcome_secret, b"key", &[], Some(key_len))
458 .await?;
459
460 let nonce_len = cipher_suite.aead_nonce_size();
461
462 let nonce = kdf_expand_with_label(
463 cipher_suite,
464 &welcome_secret,
465 b"nonce",
466 &[],
467 Some(nonce_len),
468 )
469 .await?;
470
471 Ok(Self {
472 cipher_suite,
473 key,
474 nonce,
475 })
476 }
477
478 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, MlsError>479 pub(crate) async fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, MlsError> {
480 self.cipher_suite
481 .aead_seal(&self.key, plaintext, None, &self.nonce)
482 .await
483 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
484 }
485
486 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
decrypt(&self, ciphertext: &[u8]) -> Result<Zeroizing<Vec<u8>>, MlsError>487 pub(crate) async fn decrypt(&self, ciphertext: &[u8]) -> Result<Zeroizing<Vec<u8>>, MlsError> {
488 self.cipher_suite
489 .aead_open(&self.key, ciphertext, None, &self.nonce)
490 .await
491 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
492 }
493 }
494
495 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_welcome_secret<P: CipherSuiteProvider>( cipher_suite: &P, joiner_secret: &JoinerSecret, psk_secret: &PskSecret, ) -> Result<Zeroizing<Vec<u8>>, MlsError>496 async fn get_welcome_secret<P: CipherSuiteProvider>(
497 cipher_suite: &P,
498 joiner_secret: &JoinerSecret,
499 psk_secret: &PskSecret,
500 ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
501 let epoch_seed = get_pre_epoch_secret(cipher_suite, psk_secret, joiner_secret).await?;
502 kdf_derive_secret(cipher_suite, &epoch_seed, b"welcome").await
503 }
504
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::vec;
    use alloc::vec::Vec;
    use mls_rs_core::crypto::CipherSuiteProvider;
    use zeroize::Zeroizing;

    use crate::{cipher_suite::CipherSuite, crypto::test_utils::test_cipher_suite_provider};

    use super::{InitSecret, JoinerSecret, KeySchedule};

    #[cfg(all(feature = "rfc_compliant", not(mls_build_async)))]
    use mls_rs_core::error::IntoAnyError;

    #[cfg(all(feature = "rfc_compliant", not(mls_build_async)))]
    use super::MlsError;

    /// Test-only conversion that moves the raw joiner-secret bytes out of their
    /// `Zeroizing` wrapper so they can be compared with test vectors.
    impl From<JoinerSecret> for Vec<u8> {
        fn from(value: JoinerSecret) -> Self {
            let JoinerSecret(mut bytes) = value;
            core::mem::take(&mut *bytes)
        }
    }

    /// Builds a [`KeySchedule`] for `cipher_suite` whose secrets are all `0x01` bytes and
    /// whose init secret is all zeros, each `kdf_extract_size()` bytes long.
    pub(crate) fn get_test_key_schedule(cipher_suite: CipherSuite) -> KeySchedule {
        let secret_len = test_cipher_suite_provider(cipher_suite).kdf_extract_size();
        let ones = || Zeroizing::new(vec![1u8; secret_len]);

        KeySchedule {
            exporter_secret: ones(),
            authentication_secret: ones(),
            external_secret: ones(),
            membership_key: ones(),
            init_secret: InitSecret::new(vec![0u8; secret_len]),
        }
    }

    impl InitSecret {
        /// Wraps raw bytes as an init secret.
        pub fn new(init_secret: Vec<u8>) -> Self {
            Self(Zeroizing::new(init_secret))
        }

        /// Generates a random init secret of the suite's KDF extract size.
        #[cfg(all(feature = "rfc_compliant", test, not(mls_build_async)))]
        #[cfg_attr(coverage_nightly, coverage(off))]
        pub fn random<P: CipherSuiteProvider>(cipher_suite: &P) -> Result<Self, MlsError> {
            let len = cipher_suite.kdf_extract_size();

            match cipher_suite.random_bytes_vec(len) {
                Ok(bytes) => Ok(Self::new(bytes)),
                Err(e) => Err(MlsError::CryptoProviderError(e.into_any_error())),
            }
        }
    }

    #[cfg(feature = "rfc_compliant")]
    impl KeySchedule {
        /// Replaces the membership key with `key`.
        pub fn set_membership_key(&mut self, key: Vec<u8>) {
            self.membership_key = Zeroizing::new(key);
        }
    }
}
564
#[cfg(test)]
mod tests {
    use crate::client::test_utils::TEST_PROTOCOL_VERSION;
    use crate::crypto::test_utils::try_test_cipher_suite_provider;
    use crate::group::key_schedule::{
        get_welcome_secret, kdf_derive_secret, kdf_expand_with_label,
    };
    use crate::group::GroupContext;
    use alloc::string::String;
    use alloc::vec::Vec;
    use mls_rs_codec::MlsEncode;
    use mls_rs_core::crypto::CipherSuiteProvider;
    use mls_rs_core::extension::ExtensionList;

    #[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
    use crate::{
        crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
        group::{
            key_schedule::KeyScheduleDerivationResult, test_utils::random_bytes, InitSecret,
            PskSecret,
        },
    };

    #[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
    use alloc::{string::ToString, vec};

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;
    use zeroize::Zeroizing;

    use super::test_utils::get_test_key_schedule;
    use super::KeySchedule;

    /// Key-schedule vectors for one cipher suite: a starting init secret followed by a
    /// chain of epochs.
    #[derive(serde::Deserialize, serde::Serialize)]
    struct TestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        group_id: Vec<u8>,
        #[serde(with = "hex::serde")]
        initial_init_secret: Vec<u8>,
        epochs: Vec<KeyScheduleEpoch>,
    }

    /// Inputs and expected outputs for a single epoch of the key schedule.
    #[derive(serde::Deserialize, serde::Serialize)]
    struct KeyScheduleEpoch {
        #[serde(with = "hex::serde")]
        commit_secret: Vec<u8>,
        #[serde(with = "hex::serde")]
        psk_secret: Vec<u8>,
        #[serde(with = "hex::serde")]
        confirmed_transcript_hash: Vec<u8>,
        #[serde(with = "hex::serde")]
        tree_hash: Vec<u8>,

        #[serde(with = "hex::serde")]
        group_context: Vec<u8>,

        #[serde(with = "hex::serde")]
        joiner_secret: Vec<u8>,
        #[serde(with = "hex::serde")]
        welcome_secret: Vec<u8>,
        #[serde(with = "hex::serde")]
        init_secret: Vec<u8>,

        #[serde(with = "hex::serde")]
        sender_data_secret: Vec<u8>,
        #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
        #[serde(with = "hex::serde")]
        encryption_secret: Vec<u8>,
        #[serde(with = "hex::serde")]
        exporter_secret: Vec<u8>,
        #[serde(with = "hex::serde")]
        epoch_authenticator: Vec<u8>,
        #[serde(with = "hex::serde")]
        external_secret: Vec<u8>,
        #[serde(with = "hex::serde")]
        confirmation_key: Vec<u8>,
        #[serde(with = "hex::serde")]
        membership_key: Vec<u8>,
        #[cfg(feature = "psk")]
        #[serde(with = "hex::serde")]
        resumption_psk: Vec<u8>,

        #[serde(with = "hex::serde")]
        external_pub: Vec<u8>,

        exporter: KeyScheduleExporter,
    }

    /// One exporter invocation and its expected output.
    #[derive(serde::Deserialize, serde::Serialize)]
    struct KeyScheduleExporter {
        label: String,
        #[serde(with = "hex::serde")]
        context: Vec<u8>,
        length: usize,
        #[serde(with = "hex::serde")]
        secret: Vec<u8>,
    }

    /// Replays the `key_schedule_test_vector` cases epoch by epoch and checks every
    /// derived secret. `generate_test_vector` is passed to `load_test_case_json!` as the
    /// fallback source of cases.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_key_schedule() {
        let test_cases: Vec<TestCase> =
            load_test_case_json!(key_schedule_test_vector, generate_test_vector());

        for test_case in test_cases {
            // Skip cipher suites the test crypto provider does not support.
            let Some(cs_provider) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
                continue;
            };

            let mut key_schedule = get_test_key_schedule(cs_provider.cipher_suite());
            key_schedule.init_secret.0 = Zeroizing::new(test_case.initial_init_secret);

            for (i, epoch) in test_case.epochs.into_iter().enumerate() {
                let context = GroupContext {
                    protocol_version: TEST_PROTOCOL_VERSION,
                    cipher_suite: cs_provider.cipher_suite(),
                    group_id: test_case.group_id.clone(),
                    epoch: i as u64,
                    tree_hash: epoch.tree_hash,
                    confirmed_transcript_hash: epoch.confirmed_transcript_hash.into(),
                    extensions: ExtensionList::new(),
                };

                assert_eq!(context.mls_encode_to_vec().unwrap(), epoch.group_context);

                let psk = epoch.psk_secret.into();
                let commit = epoch.commit_secret.into();

                let key_schedule_res = KeySchedule::from_key_schedule(
                    &key_schedule,
                    &commit,
                    &context,
                    #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
                    32,
                    &psk,
                    &cs_provider,
                )
                .await
                .unwrap();

                // The next epoch is derived from this epoch's schedule.
                key_schedule = key_schedule_res.key_schedule;

                let welcome =
                    get_welcome_secret(&cs_provider, &key_schedule_res.joiner_secret, &psk)
                        .await
                        .unwrap();

                assert_eq!(*welcome, epoch.welcome_secret);

                let expected: Vec<u8> = key_schedule_res.joiner_secret.into();
                assert_eq!(epoch.joiner_secret, expected);

                assert_eq!(&key_schedule.init_secret.0.to_vec(), &epoch.init_secret);

                assert_eq!(
                    epoch.sender_data_secret,
                    *key_schedule_res.epoch_secrets.sender_data_secret.to_vec()
                );

                #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
                assert_eq!(
                    epoch.encryption_secret,
                    *key_schedule_res.epoch_secrets.secret_tree.get_root_secret()
                );

                assert_eq!(epoch.exporter_secret, key_schedule.exporter_secret.to_vec());

                assert_eq!(
                    epoch.epoch_authenticator,
                    key_schedule.authentication_secret.to_vec()
                );

                assert_eq!(epoch.external_secret, key_schedule.external_secret.to_vec());

                assert_eq!(
                    epoch.confirmation_key,
                    key_schedule_res.confirmation_key.to_vec()
                );

                assert_eq!(epoch.membership_key, key_schedule.membership_key.to_vec());

                #[cfg(feature = "psk")]
                {
                    let expected: Vec<u8> =
                        key_schedule_res.epoch_secrets.resumption_secret.to_vec();

                    assert_eq!(epoch.resumption_psk, expected);
                }

                let (_external_sec, external_pub) = key_schedule
                    .get_external_key_pair(&cs_provider)
                    .await
                    .unwrap();

                assert_eq!(epoch.external_pub, *external_pub);

                let exp = epoch.exporter;

                let exported = key_schedule
                    .export_secret(exp.label.as_bytes(), &exp.context, exp.length, &cs_provider)
                    .await
                    .unwrap();

                assert_eq!(exported.to_vec(), exp.secret);
            }
        }
    }

    /// Generates two chained epochs of random key-schedule vectors per supported cipher
    /// suite.
    #[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_test_vector() -> Vec<TestCase> {
        let mut test_cases = vec![];

        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let cs_provider = test_cipher_suite_provider(cipher_suite);
            let key_size = cs_provider.kdf_extract_size();

            let mut group_context = GroupContext {
                protocol_version: TEST_PROTOCOL_VERSION,
                cipher_suite: cs_provider.cipher_suite(),
                group_id: b"my group 5".to_vec(),
                epoch: 0,
                tree_hash: random_bytes(key_size),
                confirmed_transcript_hash: random_bytes(key_size).into(),
                extensions: Default::default(),
            };

            let initial_init_secret = InitSecret::random(&cs_provider).unwrap();
            let mut key_schedule = get_test_key_schedule(cs_provider.cipher_suite());
            key_schedule.init_secret = initial_init_secret.clone();

            let commit_secret = random_bytes(key_size).into();
            let psk_secret = PskSecret::new(&cs_provider);

            let key_schedule_res = KeySchedule::from_key_schedule(
                &key_schedule,
                &commit_secret,
                &group_context,
                #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
                32,
                &psk_secret,
                &cs_provider,
            )
            .unwrap();

            key_schedule = key_schedule_res.key_schedule.clone();

            let epoch1 = KeyScheduleEpoch::new(
                key_schedule_res,
                psk_secret,
                commit_secret.to_vec(),
                &group_context,
                &cs_provider,
            );

            group_context.epoch += 1;
            group_context.confirmed_transcript_hash = random_bytes(key_size).into();
            group_context.tree_hash = random_bytes(key_size);

            let commit_secret = random_bytes(key_size).into();
            let psk_secret = PskSecret::new(&cs_provider);

            let key_schedule_res = KeySchedule::from_key_schedule(
                &key_schedule,
                &commit_secret,
                &group_context,
                #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
                32,
                &psk_secret,
                &cs_provider,
            )
            .unwrap();

            let epoch2 = KeyScheduleEpoch::new(
                key_schedule_res,
                psk_secret,
                commit_secret.to_vec(),
                &group_context,
                &cs_provider,
            );

            let test_case = TestCase {
                cipher_suite: cs_provider.cipher_suite().into(),
                group_id: group_context.group_id.clone(),
                initial_init_secret: initial_init_secret.0.to_vec(),
                epochs: vec![epoch1, epoch2],
            };

            test_cases.push(test_case);
        }

        test_cases
    }

    /// Stub used when vectors cannot be generated: in async builds, and in sync builds
    /// without the `rfc_compliant` feature.
    #[cfg(not(all(not(mls_build_async), feature = "rfc_compliant")))]
    fn generate_test_vector() -> Vec<TestCase> {
        panic!("Tests cannot be generated in async mode or without the rfc_compliant feature");
    }

    #[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
    impl KeyScheduleEpoch {
        /// Captures all outputs of one key-schedule run as an expected-value record.
        #[cfg_attr(coverage_nightly, coverage(off))]
        fn new<P: CipherSuiteProvider>(
            key_schedule_res: KeyScheduleDerivationResult,
            psk_secret: PskSecret,
            commit_secret: Vec<u8>,
            group_context: &GroupContext,
            cs: &P,
        ) -> Self {
            let (_external_sec, external_pub) = key_schedule_res
                .key_schedule
                .get_external_key_pair(cs)
                .unwrap();

            let mut exporter = KeyScheduleExporter {
                label: "exporter label 15".to_string(),
                context: b"exporter context".to_vec(),
                length: 64,
                secret: vec![],
            };

            exporter.secret = key_schedule_res
                .key_schedule
                .export_secret(
                    exporter.label.as_bytes(),
                    &exporter.context,
                    exporter.length,
                    cs,
                )
                .unwrap()
                .to_vec();

            let welcome_secret =
                get_welcome_secret(cs, &key_schedule_res.joiner_secret, &psk_secret)
                    .unwrap()
                    .to_vec();

            KeyScheduleEpoch {
                commit_secret,
                welcome_secret,
                psk_secret: psk_secret.to_vec(),
                group_context: group_context.mls_encode_to_vec().unwrap(),
                joiner_secret: key_schedule_res.joiner_secret.into(),
                init_secret: key_schedule_res.key_schedule.init_secret.0.to_vec(),
                sender_data_secret: key_schedule_res.epoch_secrets.sender_data_secret.to_vec(),
                #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
                encryption_secret: key_schedule_res.epoch_secrets.secret_tree.get_root_secret(),
                exporter_secret: key_schedule_res.key_schedule.exporter_secret.to_vec(),
                epoch_authenticator: key_schedule_res.key_schedule.authentication_secret.to_vec(),
                external_secret: key_schedule_res.key_schedule.external_secret.to_vec(),
                confirmation_key: key_schedule_res.confirmation_key.to_vec(),
                membership_key: key_schedule_res.key_schedule.membership_key.to_vec(),
                #[cfg(feature = "psk")]
                resumption_psk: key_schedule_res.epoch_secrets.resumption_secret.to_vec(),
                external_pub: external_pub.to_vec(),
                exporter,
                confirmed_transcript_hash: group_context.confirmed_transcript_hash.to_vec(),
                tree_hash: group_context.tree_hash.clone(),
            }
        }
    }

    /// One `ExpandWithLabel` vector from the crypto-basics interop suite.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct ExpandWithLabelTestCase {
        #[serde(with = "hex::serde")]
        secret: Vec<u8>,
        label: String,
        #[serde(with = "hex::serde")]
        context: Vec<u8>,
        length: usize,
        #[serde(with = "hex::serde")]
        out: Vec<u8>,
    }

    /// One `DeriveSecret` vector from the crypto-basics interop suite.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct DeriveSecretTestCase {
        #[serde(with = "hex::serde")]
        secret: Vec<u8>,
        label: String,
        #[serde(with = "hex::serde")]
        out: Vec<u8>,
    }

    /// Per-cipher-suite entry of the crypto-basics interop vectors used here.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct InteropTestCase {
        cipher_suite: u16,
        expand_with_label: ExpandWithLabelTestCase,
        derive_secret: DeriveSecretTestCase,
    }

    /// Checks `kdf_expand_with_label` and `kdf_derive_secret` against the
    /// crypto-basics interop vectors.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_basic_crypto_test_vectors() {
        // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/crypto-basics.json
        let test_cases: Vec<InteropTestCase> =
            load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());

        for test_case in test_cases {
            // Skip cipher suites the test crypto provider does not support.
            let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
                continue;
            };

            let test_exp = &test_case.expand_with_label;

            let computed = kdf_expand_with_label(
                &cs,
                &test_exp.secret,
                test_exp.label.as_bytes(),
                &test_exp.context,
                Some(test_exp.length),
            )
            .await
            .unwrap();

            assert_eq!(&computed.to_vec(), &test_exp.out);

            let test_derive = &test_case.derive_secret;

            let computed =
                kdf_derive_secret(&cs, &test_derive.secret, test_derive.label.as_bytes())
                    .await
                    .unwrap();

            assert_eq!(&computed.to_vec(), &test_derive.out);
        }
    }
}
989