1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5 use core::ops::{Deref, DerefMut};
6
7 use alloc::format;
8 use rand::RngCore;
9
10 use super::*;
11 use crate::{
12 client::{
13 test_utils::{
14 test_client_with_key_pkg, test_client_with_key_pkg_custom, TEST_CIPHER_SUITE,
15 TEST_PROTOCOL_VERSION,
16 },
17 MlsError,
18 },
19 client_builder::test_utils::{TestClientBuilder, TestClientConfig},
20 crypto::test_utils::test_cipher_suite_provider,
21 extension::ExtensionType,
22 identity::basic::BasicIdentityProvider,
23 identity::test_utils::get_test_signing_identity,
24 key_package::{KeyPackageGeneration, KeyPackageGenerator},
25 mls_rules::{CommitOptions, DefaultMlsRules},
26 tree_kem::{leaf_node::test_utils::get_test_capabilities, Lifetime},
27 };
28
29 use crate::extension::RequiredCapabilitiesExt;
30
31 #[cfg(not(feature = "by_ref_proposal"))]
32 use crate::crypto::HpkePublicKey;
33
/// Group identifier used by the helpers in this module when they create a test group.
pub const TEST_GROUP: &[u8] = b"group".as_slice();
35
36 #[derive(Clone)]
37 pub(crate) struct TestGroup {
38 pub group: Group<TestClientConfig>,
39 }
40
41 impl TestGroup {
42 #[cfg(feature = "external_client")]
43 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
propose(&mut self, proposal: Proposal) -> MlsMessage44 pub(crate) async fn propose(&mut self, proposal: Proposal) -> MlsMessage {
45 self.group.proposal_message(proposal, vec![]).await.unwrap()
46 }
47
48 #[cfg(feature = "by_ref_proposal")]
49 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
update_proposal(&mut self) -> Proposal50 pub(crate) async fn update_proposal(&mut self) -> Proposal {
51 self.group.update_proposal(None, None).await.unwrap()
52 }
53
54 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
join_with_custom_config<F>( &mut self, name: &str, custom_kp: bool, mut config: F, ) -> Result<(TestGroup, MlsMessage), MlsError> where F: FnMut(&mut TestClientConfig),55 pub(crate) async fn join_with_custom_config<F>(
56 &mut self,
57 name: &str,
58 custom_kp: bool,
59 mut config: F,
60 ) -> Result<(TestGroup, MlsMessage), MlsError>
61 where
62 F: FnMut(&mut TestClientConfig),
63 {
64 let (mut new_client, new_key_package) = if custom_kp {
65 test_client_with_key_pkg_custom(
66 self.group.protocol_version(),
67 self.group.cipher_suite(),
68 name,
69 &mut config,
70 )
71 .await
72 } else {
73 test_client_with_key_pkg(
74 self.group.protocol_version(),
75 self.group.cipher_suite(),
76 name,
77 )
78 .await
79 };
80
81 // Add new member to the group
82 let CommitOutput {
83 welcome_messages,
84 ratchet_tree,
85 commit_message,
86 ..
87 } = self
88 .group
89 .commit_builder()
90 .add_member(new_key_package)
91 .unwrap()
92 .build()
93 .await
94 .unwrap();
95
96 // Apply the commit to the original group
97 self.group.apply_pending_commit().await.unwrap();
98
99 config(&mut new_client.config);
100
101 // Group from new member's perspective
102 let (new_group, _) = Group::join(
103 &welcome_messages[0],
104 ratchet_tree,
105 new_client.config.clone(),
106 new_client.signer.clone().unwrap(),
107 )
108 .await?;
109
110 let new_test_group = TestGroup { group: new_group };
111
112 Ok((new_test_group, commit_message))
113 }
114
115 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
join(&mut self, name: &str) -> (TestGroup, MlsMessage)116 pub(crate) async fn join(&mut self, name: &str) -> (TestGroup, MlsMessage) {
117 self.join_with_custom_config(name, false, |_| ())
118 .await
119 .unwrap()
120 }
121
122 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
process_pending_commit( &mut self, ) -> Result<CommitMessageDescription, MlsError>123 pub(crate) async fn process_pending_commit(
124 &mut self,
125 ) -> Result<CommitMessageDescription, MlsError> {
126 self.group.apply_pending_commit().await
127 }
128
129 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
process_message( &mut self, message: MlsMessage, ) -> Result<ReceivedMessage, MlsError>130 pub(crate) async fn process_message(
131 &mut self,
132 message: MlsMessage,
133 ) -> Result<ReceivedMessage, MlsError> {
134 self.group.process_incoming_message(message).await
135 }
136
137 #[cfg(feature = "private_message")]
138 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
make_plaintext(&mut self, content: Content) -> MlsMessage139 pub(crate) async fn make_plaintext(&mut self, content: Content) -> MlsMessage {
140 let auth_content = AuthenticatedContent::new_signed(
141 &self.group.cipher_suite_provider,
142 &self.group.state.context,
143 Sender::Member(*self.group.private_tree.self_index),
144 content,
145 &self.group.signer,
146 WireFormat::PublicMessage,
147 Vec::new(),
148 )
149 .await
150 .unwrap();
151
152 self.group.format_for_wire(auth_content).await.unwrap()
153 }
154 }
155
156 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_test_group_context(epoch: u64, cipher_suite: CipherSuite) -> GroupContext157 pub(crate) async fn get_test_group_context(epoch: u64, cipher_suite: CipherSuite) -> GroupContext {
158 let cs = test_cipher_suite_provider(cipher_suite);
159
160 GroupContext {
161 protocol_version: TEST_PROTOCOL_VERSION,
162 cipher_suite,
163 group_id: TEST_GROUP.to_vec(),
164 epoch,
165 tree_hash: cs.hash(&[1, 2, 3]).await.unwrap(),
166 confirmed_transcript_hash: cs.hash(&[3, 2, 1]).await.unwrap().into(),
167 extensions: ExtensionList::from(vec![]),
168 }
169 }
170
/// Builds a [`GroupContext`] with a caller-supplied `group_id`.
///
/// The tree hash, confirmed transcript hash and extension list are all empty.
#[cfg(feature = "prior_epoch")]
pub(crate) fn get_test_group_context_with_id(
    group_id: Vec<u8>,
    epoch: u64,
    cipher_suite: CipherSuite,
) -> GroupContext {
    let empty_transcript = ConfirmedTranscriptHash::from(Vec::new());

    GroupContext {
        group_id,
        epoch,
        cipher_suite,
        protocol_version: TEST_PROTOCOL_VERSION,
        tree_hash: Vec::new(),
        confirmed_transcript_hash: empty_transcript,
        extensions: ExtensionList::from(Vec::new()),
    }
}
187
group_extensions() -> ExtensionList188 pub(crate) fn group_extensions() -> ExtensionList {
189 let required_capabilities = RequiredCapabilitiesExt::default();
190
191 let mut extensions = ExtensionList::new();
192 extensions.set_from(required_capabilities).unwrap();
193 extensions
194 }
195
lifetime() -> Lifetime196 pub(crate) fn lifetime() -> Lifetime {
197 Lifetime::years(1).unwrap()
198 }
199
200 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_member( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, identifier: &[u8], ) -> (KeyPackageGeneration, SignatureSecretKey)201 pub(crate) async fn test_member(
202 protocol_version: ProtocolVersion,
203 cipher_suite: CipherSuite,
204 identifier: &[u8],
205 ) -> (KeyPackageGeneration, SignatureSecretKey) {
206 let (signing_identity, signing_key) = get_test_signing_identity(cipher_suite, identifier).await;
207
208 let key_package_generator = KeyPackageGenerator {
209 protocol_version,
210 cipher_suite_provider: &test_cipher_suite_provider(cipher_suite),
211 signing_identity: &signing_identity,
212 signing_key: &signing_key,
213 identity_provider: &BasicIdentityProvider,
214 };
215
216 let key_package = key_package_generator
217 .generate(
218 lifetime(),
219 get_test_capabilities(),
220 ExtensionList::default(),
221 ExtensionList::default(),
222 )
223 .await
224 .unwrap();
225
226 (key_package, signing_key)
227 }
228
229 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_group_custom( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, extension_types: Vec<ExtensionType>, leaf_extensions: Option<ExtensionList>, commit_options: Option<CommitOptions>, ) -> TestGroup230 pub(crate) async fn test_group_custom(
231 protocol_version: ProtocolVersion,
232 cipher_suite: CipherSuite,
233 extension_types: Vec<ExtensionType>,
234 leaf_extensions: Option<ExtensionList>,
235 commit_options: Option<CommitOptions>,
236 ) -> TestGroup {
237 let leaf_extensions = leaf_extensions.unwrap_or_default();
238 let commit_options = commit_options.unwrap_or_default();
239
240 let (signing_identity, secret_key) = get_test_signing_identity(cipher_suite, b"member").await;
241
242 let group = TestClientBuilder::new_for_test()
243 .leaf_node_extensions(leaf_extensions)
244 .mls_rules(DefaultMlsRules::default().with_commit_options(commit_options))
245 .extension_types(extension_types)
246 .protocol_versions(ProtocolVersion::all())
247 .used_protocol_version(protocol_version)
248 .signing_identity(signing_identity.clone(), secret_key, cipher_suite)
249 .build()
250 .create_group_with_id(TEST_GROUP.to_vec(), group_extensions())
251 .await
252 .unwrap();
253
254 TestGroup { group }
255 }
256
257 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_group( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, ) -> TestGroup258 pub(crate) async fn test_group(
259 protocol_version: ProtocolVersion,
260 cipher_suite: CipherSuite,
261 ) -> TestGroup {
262 test_group_custom(
263 protocol_version,
264 cipher_suite,
265 Default::default(),
266 None,
267 None,
268 )
269 .await
270 }
271
272 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_group_custom_config<F>( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, custom: F, ) -> TestGroup where F: FnOnce(TestClientBuilder) -> TestClientBuilder,273 pub(crate) async fn test_group_custom_config<F>(
274 protocol_version: ProtocolVersion,
275 cipher_suite: CipherSuite,
276 custom: F,
277 ) -> TestGroup
278 where
279 F: FnOnce(TestClientBuilder) -> TestClientBuilder,
280 {
281 let (signing_identity, secret_key) = get_test_signing_identity(cipher_suite, b"member").await;
282
283 let client_builder = TestClientBuilder::new_for_test().used_protocol_version(protocol_version);
284
285 let group = custom(client_builder)
286 .signing_identity(signing_identity.clone(), secret_key, cipher_suite)
287 .build()
288 .create_group_with_id(TEST_GROUP.to_vec(), group_extensions())
289 .await
290 .unwrap();
291
292 TestGroup { group }
293 }
294
295 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_n_member_group( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, num_members: usize, ) -> Vec<TestGroup>296 pub(crate) async fn test_n_member_group(
297 protocol_version: ProtocolVersion,
298 cipher_suite: CipherSuite,
299 num_members: usize,
300 ) -> Vec<TestGroup> {
301 let group = test_group(protocol_version, cipher_suite).await;
302
303 let mut groups = vec![group];
304
305 for i in 1..num_members {
306 let (new_group, commit) = groups.get_mut(0).unwrap().join(&format!("name {i}")).await;
307 process_commit(&mut groups, commit, 0).await;
308 groups.push(new_group);
309 }
310
311 groups
312 }
313
314 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
process_commit(groups: &mut [TestGroup], commit: MlsMessage, excluded: u32)315 pub(crate) async fn process_commit(groups: &mut [TestGroup], commit: MlsMessage, excluded: u32) {
316 for g in groups
317 .iter_mut()
318 .filter(|g| g.group.current_member_index() != excluded)
319 {
320 g.process_message(commit.clone()).await.unwrap();
321 }
322 }
323
get_test_25519_key(key_byte: u8) -> HpkePublicKey324 pub(crate) fn get_test_25519_key(key_byte: u8) -> HpkePublicKey {
325 vec![key_byte; 32].into()
326 }
327
328 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_test_groups_with_features( n: usize, extensions: ExtensionList, leaf_extensions: ExtensionList, ) -> Vec<Group<TestClientConfig>>329 pub(crate) async fn get_test_groups_with_features(
330 n: usize,
331 extensions: ExtensionList,
332 leaf_extensions: ExtensionList,
333 ) -> Vec<Group<TestClientConfig>> {
334 let mut clients = Vec::new();
335
336 for i in 0..n {
337 let (identity, secret_key) =
338 get_test_signing_identity(TEST_CIPHER_SUITE, format!("member{i}").as_bytes()).await;
339
340 clients.push(
341 TestClientBuilder::new_for_test()
342 .extension_type(999.into())
343 .leaf_node_extensions(leaf_extensions.clone())
344 .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
345 .build(),
346 );
347 }
348
349 let group = clients[0]
350 .create_group_with_id(b"TEST GROUP".to_vec(), extensions)
351 .await
352 .unwrap();
353
354 let mut groups = vec![group];
355
356 for client in clients.iter().skip(1) {
357 let key_package = client.generate_key_package_message().await.unwrap();
358
359 let commit_output = groups[0]
360 .commit_builder()
361 .add_member(key_package)
362 .unwrap()
363 .build()
364 .await
365 .unwrap();
366
367 groups[0].apply_pending_commit().await.unwrap();
368
369 for group in groups.iter_mut().skip(1) {
370 group
371 .process_incoming_message(commit_output.commit_message.clone())
372 .await
373 .unwrap();
374 }
375
376 groups.push(
377 client
378 .join_group(None, &commit_output.welcome_messages[0])
379 .await
380 .unwrap()
381 .0,
382 );
383 }
384
385 groups
386 }
387
random_bytes(count: usize) -> Vec<u8>388 pub fn random_bytes(count: usize) -> Vec<u8> {
389 let mut buf = vec![0; count];
390 rand::thread_rng().fill_bytes(&mut buf);
391 buf
392 }
393
394 pub(crate) struct GroupWithoutKeySchedule {
395 inner: Group<TestClientConfig>,
396 pub secrets: Option<(TreeKemPrivate, PathSecret)>,
397 pub provisional_public_state: Option<ProvisionalState>,
398 }
399
400 impl Deref for GroupWithoutKeySchedule {
401 type Target = Group<TestClientConfig>;
402
403 #[cfg_attr(coverage_nightly, coverage(off))]
deref(&self) -> &Self::Target404 fn deref(&self) -> &Self::Target {
405 &self.inner
406 }
407 }
408
409 impl DerefMut for GroupWithoutKeySchedule {
410 #[cfg_attr(coverage_nightly, coverage(off))]
deref_mut(&mut self) -> &mut Self::Target411 fn deref_mut(&mut self) -> &mut Self::Target {
412 &mut self.inner
413 }
414 }
415
#[cfg(feature = "rfc_compliant")]
impl GroupWithoutKeySchedule {
    /// Wraps a fresh single-member group from [`test_group`] (using
    /// `TEST_PROTOCOL_VERSION` and `cs`), with nothing recorded yet.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn new(cs: CipherSuite) -> Self {
        let TestGroup { group } = test_group(TEST_PROTOCOL_VERSION, cs).await;

        Self {
            inner: group,
            secrets: None,
            provisional_public_state: None,
        }
    }
}
427
428 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
429 #[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
430 #[cfg_attr(
431 all(not(target_arch = "wasm32"), mls_build_async),
432 maybe_async::must_be_async
433 )]
434 impl MessageProcessor for GroupWithoutKeySchedule {
435 type CipherSuiteProvider = <Group<TestClientConfig> as MessageProcessor>::CipherSuiteProvider;
436 type OutputType = <Group<TestClientConfig> as MessageProcessor>::OutputType;
437 type PreSharedKeyStorage = <Group<TestClientConfig> as MessageProcessor>::PreSharedKeyStorage;
438 type IdentityProvider = <Group<TestClientConfig> as MessageProcessor>::IdentityProvider;
439 type MlsRules = <Group<TestClientConfig> as MessageProcessor>::MlsRules;
440
group_state(&self) -> &GroupState441 fn group_state(&self) -> &GroupState {
442 self.inner.group_state()
443 }
444
445 #[cfg_attr(coverage_nightly, coverage(off))]
group_state_mut(&mut self) -> &mut GroupState446 fn group_state_mut(&mut self) -> &mut GroupState {
447 self.inner.group_state_mut()
448 }
449
mls_rules(&self) -> Self::MlsRules450 fn mls_rules(&self) -> Self::MlsRules {
451 self.inner.mls_rules()
452 }
453
identity_provider(&self) -> Self::IdentityProvider454 fn identity_provider(&self) -> Self::IdentityProvider {
455 self.inner.identity_provider()
456 }
457
cipher_suite_provider(&self) -> &Self::CipherSuiteProvider458 fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider {
459 self.inner.cipher_suite_provider()
460 }
461
psk_storage(&self) -> Self::PreSharedKeyStorage462 fn psk_storage(&self) -> Self::PreSharedKeyStorage {
463 self.inner.psk_storage()
464 }
465
can_continue_processing(&self, provisional_state: &ProvisionalState) -> bool466 fn can_continue_processing(&self, provisional_state: &ProvisionalState) -> bool {
467 self.inner.can_continue_processing(provisional_state)
468 }
469
470 #[cfg(feature = "private_message")]
471 #[cfg_attr(coverage_nightly, coverage(off))]
min_epoch_available(&self) -> Option<u64>472 fn min_epoch_available(&self) -> Option<u64> {
473 self.inner.min_epoch_available()
474 }
475
apply_update_path( &mut self, sender: LeafIndex, update_path: &ValidatedUpdatePath, provisional_state: &mut ProvisionalState, ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError>476 async fn apply_update_path(
477 &mut self,
478 sender: LeafIndex,
479 update_path: &ValidatedUpdatePath,
480 provisional_state: &mut ProvisionalState,
481 ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError> {
482 self.inner
483 .apply_update_path(sender, update_path, provisional_state)
484 .await
485 }
486
487 #[cfg(feature = "private_message")]
488 #[cfg_attr(coverage_nightly, coverage(off))]
process_ciphertext( &mut self, cipher_text: &PrivateMessage, ) -> Result<EventOrContent<Self::OutputType>, MlsError>489 async fn process_ciphertext(
490 &mut self,
491 cipher_text: &PrivateMessage,
492 ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
493 self.inner.process_ciphertext(cipher_text).await
494 }
495
496 #[cfg_attr(coverage_nightly, coverage(off))]
verify_plaintext_authentication( &self, message: PublicMessage, ) -> Result<EventOrContent<Self::OutputType>, MlsError>497 async fn verify_plaintext_authentication(
498 &self,
499 message: PublicMessage,
500 ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
501 self.inner.verify_plaintext_authentication(message).await
502 }
503
update_key_schedule( &mut self, secrets: Option<(TreeKemPrivate, PathSecret)>, _interim_transcript_hash: InterimTranscriptHash, _confirmation_tag: &ConfirmationTag, provisional_public_state: ProvisionalState, ) -> Result<(), MlsError>504 async fn update_key_schedule(
505 &mut self,
506 secrets: Option<(TreeKemPrivate, PathSecret)>,
507 _interim_transcript_hash: InterimTranscriptHash,
508 _confirmation_tag: &ConfirmationTag,
509 provisional_public_state: ProvisionalState,
510 ) -> Result<(), MlsError> {
511 self.provisional_public_state = Some(provisional_public_state);
512 self.secrets = secrets;
513 Ok(())
514 }
515
516 #[cfg(feature = "private_message")]
517 #[cfg_attr(coverage_nightly, coverage(off))]
self_index(&self) -> Option<LeafIndex>518 fn self_index(&self) -> Option<LeafIndex> {
519 <Group<TestClientConfig> as MessageProcessor>::self_index(&self.inner)
520 }
521 }
522