1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use assert_matches::assert_matches;
6 use cfg_if::cfg_if;
7 use mls_rs::client_builder::MlsConfig;
8 use mls_rs::error::MlsError;
9 use mls_rs::group::proposal::Proposal;
10 use mls_rs::group::ReceivedMessage;
11 use mls_rs::identity::SigningIdentity;
12 use mls_rs::mls_rules::CommitOptions;
13 use mls_rs::ExtensionList;
14 use mls_rs::MlsMessage;
15 use mls_rs::ProtocolVersion;
16 use mls_rs::{CipherSuite, Group};
17 use mls_rs::{Client, CryptoProvider};
18 use mls_rs_core::crypto::CipherSuiteProvider;
19 use rand::prelude::SliceRandom;
20 use rand::RngCore;
21 
22 use mls_rs::test_utils::{all_process_message, get_test_basic_credential};
23 
24 #[cfg(mls_build_async)]
25 use futures::Future;
26 
27 cfg_if! {
28     if #[cfg(target_arch = "wasm32")] {
29         use mls_rs_crypto_webcrypto::WebCryptoProvider as TestCryptoProvider;
30     } else {
31         use mls_rs_crypto_openssl::OpensslCryptoProvider as TestCryptoProvider;
32     }
33 }
34 
35 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
generate_client( cipher_suite: CipherSuite, protocol_version: ProtocolVersion, id: usize, encrypt_controls: bool, ) -> Client<impl MlsConfig>36 async fn generate_client(
37     cipher_suite: CipherSuite,
38     protocol_version: ProtocolVersion,
39     id: usize,
40     encrypt_controls: bool,
41 ) -> Client<impl MlsConfig> {
42     mls_rs::test_utils::generate_basic_client(
43         cipher_suite,
44         protocol_version,
45         id,
46         None,
47         encrypt_controls,
48         &TestCryptoProvider::default(),
49         None,
50     )
51     .await
52 }
53 
54 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_test_groups( version: ProtocolVersion, cipher_suite: CipherSuite, num_participants: usize, encrypt_controls: bool, ) -> Vec<Group<impl MlsConfig>>55 pub async fn get_test_groups(
56     version: ProtocolVersion,
57     cipher_suite: CipherSuite,
58     num_participants: usize,
59     encrypt_controls: bool,
60 ) -> Vec<Group<impl MlsConfig>> {
61     mls_rs::test_utils::get_test_groups(
62         version,
63         cipher_suite,
64         num_participants,
65         None,
66         encrypt_controls,
67         &TestCryptoProvider::default(),
68     )
69     .await
70 }
71 
72 use rand::seq::IteratorRandom;
73 
74 #[cfg(target_arch = "wasm32")]
75 wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
76 
77 #[cfg(target_arch = "wasm32")]
78 use wasm_bindgen_test::wasm_bindgen_test as futures_test;
79 
80 #[cfg(all(mls_build_async, not(target_arch = "wasm32")))]
81 use futures_test::test as futures_test;
82 
/// Awaits `test` for every protocol version and every cipher suite supported by
/// `TestCryptoProvider`, with the `encrypt_controls` flag set to `true` and then `false`.
///
/// Each invocation is given a participant count of 10.
#[cfg(feature = "private_message")]
#[cfg(mls_build_async)]
async fn test_on_all_params<F, Fut>(test: F)
where
    F: Fn(ProtocolVersion, CipherSuite, usize, bool) -> Fut,
    Fut: Future<Output = ()>,
{
    const PARTICIPANTS: usize = 10;

    for protocol_version in ProtocolVersion::all() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            for encrypted in [true, false] {
                test(protocol_version, cipher_suite, PARTICIPANTS, encrypted).await;
            }
        }
    }
}
98 
/// Calls `test` for every protocol version and every cipher suite supported by
/// `TestCryptoProvider`, with the `encrypt_controls` flag set to `true` and then `false`.
///
/// Each invocation is given a participant count of 10.
#[cfg(feature = "private_message")]
#[cfg(not(mls_build_async))]
fn test_on_all_params<F>(test: F)
where
    F: Fn(ProtocolVersion, CipherSuite, usize, bool),
{
    const PARTICIPANTS: usize = 10;

    for protocol_version in ProtocolVersion::all() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            for encrypted in [true, false] {
                test(protocol_version, cipher_suite, PARTICIPANTS, encrypted);
            }
        }
    }
}
113 
/// Variant used when the `private_message` feature is disabled (async build).
///
/// Forwards `test` unchanged to `test_on_all_params_plaintext`.
#[cfg(not(feature = "private_message"))]
#[cfg(mls_build_async)]
async fn test_on_all_params<F, Fut>(test: F)
where
    F: Fn(ProtocolVersion, CipherSuite, usize, bool) -> Fut,
    Fut: Future<Output = ()>,
{
    test_on_all_params_plaintext(test).await;
}
123 
/// Variant used when the `private_message` feature is disabled (sync build).
///
/// Forwards `test` unchanged to `test_on_all_params_plaintext`.
#[cfg(not(feature = "private_message"))]
#[cfg(not(mls_build_async))]
fn test_on_all_params<F>(test: F)
where
    F: Fn(ProtocolVersion, CipherSuite, usize, bool),
{
    test_on_all_params_plaintext(test);
}
132 
/// Awaits `test` for every protocol version and every cipher suite supported by
/// `TestCryptoProvider`, always passing `false` as the `encrypt_controls` flag.
///
/// Each invocation is given a participant count of 10.
#[cfg(mls_build_async)]
async fn test_on_all_params_plaintext<F, Fut>(test: F)
where
    F: Fn(ProtocolVersion, CipherSuite, usize, bool) -> Fut,
    Fut: Future<Output = ()>,
{
    const PARTICIPANTS: usize = 10;
    const ENCRYPT_CONTROLS: bool = false;

    for protocol_version in ProtocolVersion::all() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            test(protocol_version, cipher_suite, PARTICIPANTS, ENCRYPT_CONTROLS).await;
        }
    }
}
145 
146 #[cfg(not(mls_build_async))]
test_on_all_params_plaintext<F>(test: F) where F: Fn(ProtocolVersion, CipherSuite, usize, bool),147 fn test_on_all_params_plaintext<F>(test: F)
148 where
149     F: Fn(ProtocolVersion, CipherSuite, usize, bool),
150 {
151     for version in ProtocolVersion::all() {
152         for cs in TestCryptoProvider::all_supported_cipher_suites() {
153             test(version, cs, 10, false);
154         }
155     }
156 }
157 
158 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_create( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, _n_participants: usize, encrypt_controls: bool, )159 async fn test_create(
160     protocol_version: ProtocolVersion,
161     cipher_suite: CipherSuite,
162     _n_participants: usize,
163     encrypt_controls: bool,
164 ) {
165     let alice = generate_client(cipher_suite, protocol_version, 0, encrypt_controls).await;
166     let bob = generate_client(cipher_suite, protocol_version, 1, encrypt_controls).await;
167     let bob_key_pkg = bob.generate_key_package_message().await.unwrap();
168 
169     // Alice creates a group and adds bob
170     let mut alice_group = alice
171         .create_group_with_id(b"group".to_vec(), ExtensionList::default())
172         .await
173         .unwrap();
174 
175     let welcome = &alice_group
176         .commit_builder()
177         .add_member(bob_key_pkg)
178         .unwrap()
179         .build()
180         .await
181         .unwrap()
182         .welcome_messages[0];
183 
184     // Upon server confirmation, alice applies the commit to her own state
185     alice_group.apply_pending_commit().await.unwrap();
186 
187     // Bob receives the welcome message and joins the group
188     let (bob_group, _) = bob.join_group(None, welcome).await.unwrap();
189 
190     assert!(Group::equal_group_state(&alice_group, &bob_group));
191 }
192 
/// Runs `test_create` through `test_on_all_params`.
#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
async fn test_create_group() {
    test_on_all_params(test_create).await;
}
197 
198 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_empty_commits( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, participants: usize, encrypt_controls: bool, )199 async fn test_empty_commits(
200     protocol_version: ProtocolVersion,
201     cipher_suite: CipherSuite,
202     participants: usize,
203     encrypt_controls: bool,
204 ) {
205     let mut groups = get_test_groups(
206         protocol_version,
207         cipher_suite,
208         participants,
209         encrypt_controls,
210     )
211     .await;
212 
213     // Loop through each participant and send a path update
214 
215     for i in 0..groups.len() {
216         // Create the commit
217         let commit_output = groups[i].commit(Vec::new()).await.unwrap();
218 
219         assert!(commit_output.welcome_messages.is_empty());
220 
221         let index = groups[i].current_member_index() as usize;
222         all_process_message(&mut groups, &commit_output.commit_message, index, true).await;
223 
224         for other_group in groups.iter() {
225             assert!(Group::equal_group_state(other_group, &groups[i]));
226         }
227     }
228 }
229 
/// Runs `test_empty_commits` through `test_on_all_params`.
#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
async fn test_group_path_updates() {
    test_on_all_params(test_empty_commits).await;
}
234 
/// For each member but the last: that member proposes an update, the next member in
/// the list commits, and all groups must then agree on state.
#[cfg(feature = "by_ref_proposal")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn test_update_proposals(
    protocol_version: ProtocolVersion,
    cipher_suite: CipherSuite,
    participants: usize,
    encrypt_controls: bool,
) {
    let mut groups = get_test_groups(
        protocol_version,
        cipher_suite,
        participants,
        encrypt_controls,
    )
    .await;

    // Member `proposer` creates the update; member `proposer + 1` commits it.
    for proposer in 0..groups.len() - 1 {
        let proposal = groups[proposer].propose_update(Vec::new()).await.unwrap();

        let proposer_index = groups[proposer].current_member_index() as usize;
        all_process_message(&mut groups, &proposal, proposer_index, false).await;

        let committer = proposer + 1;
        let commit_output = groups[committer].commit(Vec::new()).await.unwrap();

        assert!(commit_output.welcome_messages.is_empty());

        // Everyone receives the commit.
        all_process_message(&mut groups, &commit_output.commit_message, committer, true).await;

        for group in groups.iter() {
            assert!(Group::equal_group_state(group, &groups[0]));
        }
    }
}
274 
/// Runs `test_update_proposals` through `test_on_all_params`.
#[cfg(feature = "by_ref_proposal")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
async fn test_group_update_proposals() {
    test_on_all_params(test_update_proposals).await;
}
280 
281 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_remove_proposals( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, participants: usize, encrypt_controls: bool, )282 async fn test_remove_proposals(
283     protocol_version: ProtocolVersion,
284     cipher_suite: CipherSuite,
285     participants: usize,
286     encrypt_controls: bool,
287 ) {
288     let mut groups = get_test_groups(
289         protocol_version,
290         cipher_suite,
291         participants,
292         encrypt_controls,
293     )
294     .await;
295 
296     // Remove people from the group one at a time
297     while groups.len() > 1 {
298         let removed_and_committer = (0..groups.len()).choose_multiple(&mut rand::thread_rng(), 2);
299 
300         let to_remove = removed_and_committer[0];
301         let committer = removed_and_committer[1];
302         let to_remove_index = groups[to_remove].current_member_index();
303 
304         let epoch_before_remove = groups[committer].current_epoch();
305 
306         let commit_output = groups[committer]
307             .commit_builder()
308             .remove_member(to_remove_index)
309             .unwrap()
310             .build()
311             .await
312             .unwrap();
313 
314         assert!(commit_output.welcome_messages.is_empty());
315 
316         let commit = commit_output.commit_message;
317         let committer_index = groups[committer].current_member_index() as usize;
318         all_process_message(&mut groups, &commit, committer_index, true).await;
319 
320         // Check that remove was effective
321         for (i, group) in groups.iter().enumerate() {
322             if i == to_remove {
323                 assert_eq!(group.current_epoch(), epoch_before_remove);
324             } else {
325                 assert_eq!(group.current_epoch(), epoch_before_remove + 1);
326                 assert!(group.roster().member_with_index(to_remove_index).is_err());
327             }
328         }
329 
330         groups.retain(|group| group.current_member_index() != to_remove_index);
331 
332         for one_group in groups.iter() {
333             assert!(Group::equal_group_state(one_group, &groups[0]))
334         }
335     }
336 }
337 
/// Runs `test_remove_proposals` through `test_on_all_params`.
#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
async fn test_group_remove_proposals() {
    test_on_all_params(test_remove_proposals).await;
}
342 
/// Every member sends 20 application messages carrying a random 1024-byte payload;
/// every other member must process each one and recover the same payload.
#[cfg(feature = "private_message")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn test_application_messages(
    protocol_version: ProtocolVersion,
    cipher_suite: CipherSuite,
    participants: usize,
    encrypt_controls: bool,
) {
    let message_count = 20;

    let mut groups = get_test_groups(
        protocol_version,
        cipher_suite,
        participants,
        encrypt_controls,
    )
    .await;

    for sender in 0..groups.len() {
        // One fresh random payload per sender, reused for all of its messages.
        let mut test_message = vec![0; 1024];
        rand::thread_rng().fill_bytes(&mut test_message);

        for _ in 0..message_count {
            let ciphertext = groups[sender]
                .encrypt_application_message(&test_message, Vec::new())
                .await
                .unwrap();

            let sender_index = groups[sender].current_member_index();

            // Everyone except the sender processes the message.
            let receivers = groups
                .iter_mut()
                .filter(|group| group.current_member_index() != sender_index);

            for receiver in receivers {
                let decrypted = receiver
                    .process_incoming_message(ciphertext.clone())
                    .await
                    .unwrap();

                assert_matches!(decrypted, ReceivedMessage::ApplicationMessage(m) if m.data() == test_message);
            }
        }
    }
}
388 
/// Alice encrypts two messages, commits, then encrypts two more; Bob processes the
/// commit first and then all four messages in reverse order of creation.
#[cfg(all(feature = "private_message", feature = "out_of_order"))]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
async fn test_out_of_order_application_messages() {
    let mut groups =
        get_test_groups(ProtocolVersion::MLS_10, CipherSuite::P256_AES128, 2, false).await;

    let mut alice_group = groups[0].clone();
    let bob_group = &mut groups[1];

    let mut ciphertexts = Vec::new();

    // Payloads 0 and 1 are encrypted before the commit.
    for payload in [0u8, 1] {
        ciphertexts.push(
            alice_group
                .encrypt_application_message(&[payload], Vec::new())
                .await
                .unwrap(),
        );
    }

    let commit = alice_group.commit(Vec::new()).await.unwrap().commit_message;

    alice_group.apply_pending_commit().await.unwrap();

    bob_group.process_incoming_message(commit).await.unwrap();

    // Payloads 2 and 3 are encrypted after the commit has been applied.
    for payload in [2u8, 3] {
        ciphertexts.push(
            alice_group
                .encrypt_application_message(&[payload], Vec::new())
                .await
                .unwrap(),
        );
    }

    // Bob receives the messages newest first; each payload equals its own index.
    for i in (0..ciphertexts.len()).rev() {
        let res = bob_group
            .process_incoming_message(ciphertexts[i].clone())
            .await
            .unwrap();

        assert_matches!(
            res,
            ReceivedMessage::ApplicationMessage(m) if m.data() == [i as u8]
        );
    }
}
444 
/// Runs `test_application_messages` through `test_on_all_params`.
#[cfg(feature = "private_message")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
async fn test_group_application_messages() {
    test_on_all_params(test_application_messages).await
}
450 
/// A single-member group encrypts an application message and then tries to process
/// it itself; this must fail with `MlsError::CantProcessMessageFromSelf`.
///
/// `_n_participants` is ignored: the group always has exactly one member.
#[cfg(feature = "private_message")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn processing_message_from_self_returns_error(
    protocol_version: ProtocolVersion,
    cipher_suite: CipherSuite,
    _n_participants: usize,
    encrypt_controls: bool,
) {
    let mut groups = get_test_groups(protocol_version, cipher_suite, 1, encrypt_controls).await;
    let creator_group = &mut groups[0];

    let own_message = creator_group
        .encrypt_application_message(b"hello self", vec![])
        .await
        .unwrap();

    let outcome = creator_group.process_incoming_message(own_message).await;
    let error = outcome.unwrap_err();

    assert_matches!(error, MlsError::CantProcessMessageFromSelf);
}
475 
/// Runs `processing_message_from_self_returns_error` through `test_on_all_params`.
#[cfg(feature = "private_message")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
async fn test_processing_message_from_self_returns_error() {
    test_on_all_params(processing_message_from_self_returns_error).await;
}
481 
/// Grows a group to ten members purely through external commits, then has every
/// member propose the removal of leaf 0 and checks that the other members see that
/// proposal.
///
/// `_n_participants` and `_encrypt_controls` are ignored: the member count is fixed by
/// `PARTICIPANT_COUNT` and every client is created with `encrypt_controls` set to `false`.
#[cfg(feature = "by_ref_proposal")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn external_commits_work(
    protocol_version: ProtocolVersion,
    cipher_suite: CipherSuite,
    _n_participants: usize,
    _encrypt_controls: bool,
) {
    const PARTICIPANT_COUNT: usize = 10;

    let creator = generate_client(cipher_suite, protocol_version, 0, false).await;

    let creator_group = creator
        .create_group_with_id(b"group".to_vec(), ExtensionList::default())
        .await
        .unwrap();

    let mut others = Vec::new();

    for id in 1..PARTICIPANT_COUNT {
        let client = generate_client(cipher_suite, protocol_version, id, false).await;
        others.push(client);
    }

    let mut groups = vec![creator_group];

    for client in &others {
        // Any current member may hand out the group info used for the external join.
        let existing_group = groups.choose_mut(&mut rand::thread_rng()).unwrap();

        let group_info = existing_group
            .group_info_message_allowing_ext_commit(true)
            .await
            .unwrap();

        let (new_group, commit) = client
            .external_commit_builder()
            .unwrap()
            .build(group_info)
            .await
            .unwrap();

        // All existing members process the joiner's commit.
        for member_group in groups.iter_mut() {
            member_group
                .process_incoming_message(commit.clone())
                .await
                .unwrap();
        }

        groups.push(new_group);
    }

    assert!(groups
        .iter()
        .all(|group| group.roster().members_iter().count() == PARTICIPANT_COUNT));

    for i in 0..groups.len() {
        let message = groups[i].propose_remove(0, Vec::new()).await.unwrap();

        for (j, group) in groups.iter_mut().enumerate() {
            // The proposer does not process its own message.
            if j == i {
                continue;
            }

            let processed = group
                .process_incoming_message(message.clone())
                .await
                .unwrap();

            let is_remove_of_leaf_zero = matches!(
                &processed,
                ReceivedMessage::Proposal(p)
                    if matches!(&p.proposal, Proposal::Remove(r) if r.to_remove() == 0)
            );

            if !is_remove_of_leaf_zero {
                panic!("expected a proposal, got {processed:?}");
            }
        }
    }
}
557 
/// Runs `external_commits_work` through `test_on_all_params_plaintext`.
#[cfg(feature = "by_ref_proposal")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
async fn test_external_commits() {
    test_on_all_params_plaintext(external_commits_work).await
}
563 
564 #[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
test_remove_nonexisting_leaf()565 async fn test_remove_nonexisting_leaf() {
566     let mut groups =
567         get_test_groups(ProtocolVersion::MLS_10, CipherSuite::P256_AES128, 10, false).await;
568 
569     groups[0]
570         .commit_builder()
571         .remove_member(5)
572         .unwrap()
573         .build()
574         .await
575         .unwrap();
576     groups[0].apply_pending_commit().await.unwrap();
577 
578     // Leaf index out of bounds
579     assert!(groups[0].commit_builder().remove_member(13).is_err());
580 
581     // Removing blank leaf causes error
582     assert!(groups[0].commit_builder().remove_member(5).is_err());
583 }
584 
/// End-to-end check of group re-initialisation onto a different cipher suite.
///
/// Alice and Bob form a group on `P256_AES128`; Alice proposes a reinit to another
/// cipher suite supported by `TestCryptoProvider`, Bob commits it, and afterwards
/// neither may create further commits in the old group. Both then obtain reinit
/// clients, form the new group, and add a third member to it.
///
/// Returns early without asserting anything when the provider supports no cipher
/// suite other than `P256_AES128`.
#[cfg(feature = "psk")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
async fn reinit_works() {
    let suite1 = CipherSuite::P256_AES128;

    let Some(suite2) = CipherSuite::all()
        .find(|cs| cs != &suite1 && TestCryptoProvider::all_supported_cipher_suites().contains(cs))
    else {
        return;
    };

    let version = ProtocolVersion::MLS_10;

    let alice1 = generate_client(suite1, version, 1, Default::default()).await;
    let bob1 = generate_client(suite1, version, 2, Default::default()).await;

    // Create a group with 2 parties
    let mut alice_group = alice1.create_group(ExtensionList::new()).await.unwrap();
    let kp = bob1.generate_key_package_message().await.unwrap();

    let welcome = &alice_group
        .commit_builder()
        .add_member(kp)
        .unwrap()
        .build()
        .await
        .unwrap()
        .welcome_messages[0];

    alice_group.apply_pending_commit().await.unwrap();

    let (mut bob_group, _) = bob1.join_group(None, welcome).await.unwrap();

    // Alice proposes reinit
    let reinit_proposal_message = alice_group
        .propose_reinit(
            None,
            ProtocolVersion::MLS_10,
            suite2,
            ExtensionList::default(),
            Vec::new(),
        )
        .await
        .unwrap();

    // Bob commits the reinit
    bob_group
        .process_incoming_message(reinit_proposal_message)
        .await
        .unwrap();

    let commit = bob_group.commit(Vec::new()).await.unwrap().commit_message;

    // Both process Bob's commit

    #[cfg(feature = "state_update")]
    {
        let state_update = bob_group.apply_pending_commit().await.unwrap().state_update;
        assert!(!state_update.is_active() && state_update.is_pending_reinit());
    }

    #[cfg(not(feature = "state_update"))]
    bob_group.apply_pending_commit().await.unwrap();

    let message = alice_group.process_incoming_message(commit).await.unwrap();

    // Anything other than a commit is a failure here; a plain `if let` would let the
    // state-update assertion be skipped silently.
    #[cfg(feature = "state_update")]
    match message {
        ReceivedMessage::Commit(commit_description) => assert!(
            !commit_description.state_update.is_active()
                && commit_description.state_update.is_pending_reinit()
        ),
        other => panic!("expected a commit, got {other:?}"),
    }

    #[cfg(not(feature = "state_update"))]
    assert_matches!(message, ReceivedMessage::Commit(_));

    // They can't create new epochs anymore
    let res = alice_group.commit(Vec::new()).await;
    assert!(res.is_err());

    let res = bob_group.commit(Vec::new()).await;
    assert!(res.is_err());

    // Get reinit clients for alice and bob, each with a fresh signature key pair
    // generated for the new cipher suite.
    let (secret_key, public_key) = TestCryptoProvider::new()
        .cipher_suite_provider(suite2)
        .unwrap()
        .signature_key_generate()
        .await
        .unwrap();

    let identity = SigningIdentity::new(get_test_basic_credential(b"bob".to_vec()), public_key);

    let bob2 = bob_group
        .get_reinit_client(Some(secret_key), Some(identity))
        .unwrap();

    let (secret_key, public_key) = TestCryptoProvider::new()
        .cipher_suite_provider(suite2)
        .unwrap()
        .signature_key_generate()
        .await
        .unwrap();

    let identity = SigningIdentity::new(get_test_basic_credential(b"alice".to_vec()), public_key);

    let alice2 = alice_group
        .get_reinit_client(Some(secret_key), Some(identity))
        .unwrap();

    // Bob produces key package, alice commits, bob joins
    let kp = bob2.generate_key_package().await.unwrap();
    let (mut alice_group, welcome) = alice2.commit(vec![kp]).await.unwrap();
    let (mut bob_group, _) = bob2.join(&welcome[0], None).await.unwrap();

    assert!(bob_group.cipher_suite() == suite2);

    // They can talk
    let carol = generate_client(suite2, version, 3, Default::default()).await;

    let kp = carol.generate_key_package_message().await.unwrap();

    let commit_output = alice_group
        .commit_builder()
        .add_member(kp)
        .unwrap()
        .build()
        .await
        .unwrap();

    alice_group.apply_pending_commit().await.unwrap();

    bob_group
        .process_incoming_message(commit_output.commit_message)
        .await
        .unwrap();

    carol
        .join_group(None, &commit_output.welcome_messages[0])
        .await
        .unwrap();
}
728 
/// Leaf 1 of a three-member group is removed and a new client joins by external
/// commit; that client must then be able to process an update proposal from member 0
/// and the commit that follows it.
#[cfg(feature = "by_ref_proposal")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
async fn external_joiner_can_process_siblings_update() {
    let mut groups =
        get_test_groups(ProtocolVersion::MLS_10, CipherSuite::P256_AES128, 3, false).await;

    // Remove leaf 1 s.t. the external joiner joins in its place
    let removal = groups[0]
        .commit_builder()
        .remove_member(1)
        .unwrap()
        .build()
        .await
        .unwrap();

    all_process_message(&mut groups, &removal.commit_message, 0, true).await;

    let info = groups[0]
        .group_info_message_allowing_ext_commit(true)
        .await
        .unwrap();

    // Create the external joiner and join
    let new_client = generate_client(
        CipherSuite::P256_AES128,
        ProtocolVersion::MLS_10,
        0xabba,
        false,
    )
    .await;

    let (mut joiner_group, external_commit) = new_client.commit_external(info).await.unwrap();

    all_process_message(&mut groups, &external_commit, 1, false).await;

    // Drop the removed member's group from the list.
    groups.remove(1);

    // New client's sibling proposes an update to blank their common parent
    let update_proposal = groups[0].propose_update(Vec::new()).await.unwrap();
    all_process_message(&mut groups, &update_proposal, 0, false).await;
    joiner_group
        .process_incoming_message(update_proposal)
        .await
        .unwrap();

    // Some other member commits
    let follow_up_commit = groups[1].commit(Vec::new()).await.unwrap().commit_message;
    all_process_message(&mut groups, &follow_up_commit, 2, true).await;
    joiner_group
        .process_incoming_message(follow_up_commit)
        .await
        .unwrap();
}
775 
776 #[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
weird_tree_scenario()777 async fn weird_tree_scenario() {
778     let mut groups =
779         get_test_groups(ProtocolVersion::MLS_10, CipherSuite::P256_AES128, 17, false).await;
780 
781     let to_remove = [0u32, 2, 5, 7, 8, 9, 15];
782 
783     let mut builder = groups[14].commit_builder();
784 
785     for idx in to_remove.iter() {
786         builder = builder.remove_member(*idx).unwrap();
787     }
788 
789     let commit = builder.build().await.unwrap();
790 
791     for idx in to_remove.into_iter().rev() {
792         groups.remove(idx as usize);
793     }
794 
795     all_process_message(&mut groups, &commit.commit_message, 14, true).await;
796 
797     let mut builder = groups.last_mut().unwrap().commit_builder();
798 
799     for idx in 0..7 {
800         builder = builder
801             .add_member(fake_key_package(5555555 + idx).await)
802             .unwrap()
803     }
804 
805     let commit = builder.remove_member(1).unwrap().build().await.unwrap();
806 
807     let idx = groups.last().unwrap().current_member_index() as usize;
808 
809     all_process_message(&mut groups, &commit.commit_message, idx, true).await;
810 }
811 
812 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
fake_key_package(id: usize) -> MlsMessage813 async fn fake_key_package(id: usize) -> MlsMessage {
814     generate_client(CipherSuite::P256_AES128, ProtocolVersion::MLS_10, id, false)
815         .await
816         .generate_key_package_message()
817         .await
818         .unwrap()
819 }
820 
821 #[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
external_info_from_commit_allows_to_join()822 async fn external_info_from_commit_allows_to_join() {
823     let cs = CipherSuite::P256_AES128;
824     let version = ProtocolVersion::MLS_10;
825 
826     let mut alice = mls_rs::test_utils::get_test_groups(
827         version,
828         cs,
829         1,
830         Some(CommitOptions::new().with_allow_external_commit(true)),
831         false,
832         &TestCryptoProvider::default(),
833     )
834     .await
835     .remove(0);
836 
837     let commit = alice.commit(vec![]).await.unwrap();
838     alice.apply_pending_commit().await.unwrap();
839     let bob = generate_client(cs, version, 0xdead, false).await;
840 
841     let (_bob, commit) = bob
842         .commit_external(commit.external_commit_group_info.unwrap())
843         .await
844         .unwrap();
845 
846     alice.process_incoming_message(commit).await.unwrap();
847 }
848