1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use core::ops::{Deref, DerefMut};
6 
7 use alloc::format;
8 use rand::RngCore;
9 
10 use super::*;
11 use crate::{
12     client::{
13         test_utils::{
14             test_client_with_key_pkg, test_client_with_key_pkg_custom, TEST_CIPHER_SUITE,
15             TEST_PROTOCOL_VERSION,
16         },
17         MlsError,
18     },
19     client_builder::test_utils::{TestClientBuilder, TestClientConfig},
20     crypto::test_utils::test_cipher_suite_provider,
21     extension::ExtensionType,
22     identity::basic::BasicIdentityProvider,
23     identity::test_utils::get_test_signing_identity,
24     key_package::{KeyPackageGeneration, KeyPackageGenerator},
25     mls_rules::{CommitOptions, DefaultMlsRules},
26     tree_kem::{leaf_node::test_utils::get_test_capabilities, Lifetime},
27 };
28 
29 use crate::extension::RequiredCapabilitiesExt;
30 
31 #[cfg(not(feature = "by_ref_proposal"))]
32 use crate::crypto::HpkePublicKey;
33 
/// Group id used by `get_test_group_context`, `test_group_custom` and
/// `test_group_custom_config` in this module.
pub const TEST_GROUP: &[u8] = b"group";
35 
/// Thin test wrapper around a [`Group`] configured with [`TestClientConfig`].
#[derive(Clone)]
pub(crate) struct TestGroup {
    /// The wrapped group; public so tests can call `Group` APIs on it directly.
    pub group: Group<TestClientConfig>,
}
40 
41 impl TestGroup {
42     #[cfg(feature = "external_client")]
43     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
propose(&mut self, proposal: Proposal) -> MlsMessage44     pub(crate) async fn propose(&mut self, proposal: Proposal) -> MlsMessage {
45         self.group.proposal_message(proposal, vec![]).await.unwrap()
46     }
47 
48     #[cfg(feature = "by_ref_proposal")]
49     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
update_proposal(&mut self) -> Proposal50     pub(crate) async fn update_proposal(&mut self) -> Proposal {
51         self.group.update_proposal(None, None).await.unwrap()
52     }
53 
54     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
join_with_custom_config<F>( &mut self, name: &str, custom_kp: bool, mut config: F, ) -> Result<(TestGroup, MlsMessage), MlsError> where F: FnMut(&mut TestClientConfig),55     pub(crate) async fn join_with_custom_config<F>(
56         &mut self,
57         name: &str,
58         custom_kp: bool,
59         mut config: F,
60     ) -> Result<(TestGroup, MlsMessage), MlsError>
61     where
62         F: FnMut(&mut TestClientConfig),
63     {
64         let (mut new_client, new_key_package) = if custom_kp {
65             test_client_with_key_pkg_custom(
66                 self.group.protocol_version(),
67                 self.group.cipher_suite(),
68                 name,
69                 &mut config,
70             )
71             .await
72         } else {
73             test_client_with_key_pkg(
74                 self.group.protocol_version(),
75                 self.group.cipher_suite(),
76                 name,
77             )
78             .await
79         };
80 
81         // Add new member to the group
82         let CommitOutput {
83             welcome_messages,
84             ratchet_tree,
85             commit_message,
86             ..
87         } = self
88             .group
89             .commit_builder()
90             .add_member(new_key_package)
91             .unwrap()
92             .build()
93             .await
94             .unwrap();
95 
96         // Apply the commit to the original group
97         self.group.apply_pending_commit().await.unwrap();
98 
99         config(&mut new_client.config);
100 
101         // Group from new member's perspective
102         let (new_group, _) = Group::join(
103             &welcome_messages[0],
104             ratchet_tree,
105             new_client.config.clone(),
106             new_client.signer.clone().unwrap(),
107         )
108         .await?;
109 
110         let new_test_group = TestGroup { group: new_group };
111 
112         Ok((new_test_group, commit_message))
113     }
114 
115     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
join(&mut self, name: &str) -> (TestGroup, MlsMessage)116     pub(crate) async fn join(&mut self, name: &str) -> (TestGroup, MlsMessage) {
117         self.join_with_custom_config(name, false, |_| ())
118             .await
119             .unwrap()
120     }
121 
122     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
process_pending_commit( &mut self, ) -> Result<CommitMessageDescription, MlsError>123     pub(crate) async fn process_pending_commit(
124         &mut self,
125     ) -> Result<CommitMessageDescription, MlsError> {
126         self.group.apply_pending_commit().await
127     }
128 
129     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
process_message( &mut self, message: MlsMessage, ) -> Result<ReceivedMessage, MlsError>130     pub(crate) async fn process_message(
131         &mut self,
132         message: MlsMessage,
133     ) -> Result<ReceivedMessage, MlsError> {
134         self.group.process_incoming_message(message).await
135     }
136 
137     #[cfg(feature = "private_message")]
138     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
make_plaintext(&mut self, content: Content) -> MlsMessage139     pub(crate) async fn make_plaintext(&mut self, content: Content) -> MlsMessage {
140         let auth_content = AuthenticatedContent::new_signed(
141             &self.group.cipher_suite_provider,
142             &self.group.state.context,
143             Sender::Member(*self.group.private_tree.self_index),
144             content,
145             &self.group.signer,
146             WireFormat::PublicMessage,
147             Vec::new(),
148         )
149         .await
150         .unwrap();
151 
152         self.group.format_for_wire(auth_content).await.unwrap()
153     }
154 }
155 
156 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_test_group_context(epoch: u64, cipher_suite: CipherSuite) -> GroupContext157 pub(crate) async fn get_test_group_context(epoch: u64, cipher_suite: CipherSuite) -> GroupContext {
158     let cs = test_cipher_suite_provider(cipher_suite);
159 
160     GroupContext {
161         protocol_version: TEST_PROTOCOL_VERSION,
162         cipher_suite,
163         group_id: TEST_GROUP.to_vec(),
164         epoch,
165         tree_hash: cs.hash(&[1, 2, 3]).await.unwrap(),
166         confirmed_transcript_hash: cs.hash(&[3, 2, 1]).await.unwrap().into(),
167         extensions: ExtensionList::from(vec![]),
168     }
169 }
170 
/// Builds a [`GroupContext`] for an arbitrary `group_id` and `epoch`.
///
/// The tree hash, confirmed transcript hash and extension list are all empty.
#[cfg(feature = "prior_epoch")]
pub(crate) fn get_test_group_context_with_id(
    group_id: Vec<u8>,
    epoch: u64,
    cipher_suite: CipherSuite,
) -> GroupContext {
    let empty_transcript = ConfirmedTranscriptHash::from(Vec::new());
    let no_extensions = ExtensionList::from(Vec::new());

    GroupContext {
        group_id,
        epoch,
        cipher_suite,
        protocol_version: TEST_PROTOCOL_VERSION,
        tree_hash: Vec::new(),
        confirmed_transcript_hash: empty_transcript,
        extensions: no_extensions,
    }
}
187 
group_extensions() -> ExtensionList188 pub(crate) fn group_extensions() -> ExtensionList {
189     let required_capabilities = RequiredCapabilitiesExt::default();
190 
191     let mut extensions = ExtensionList::new();
192     extensions.set_from(required_capabilities).unwrap();
193     extensions
194 }
195 
lifetime() -> Lifetime196 pub(crate) fn lifetime() -> Lifetime {
197     Lifetime::years(1).unwrap()
198 }
199 
200 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_member( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, identifier: &[u8], ) -> (KeyPackageGeneration, SignatureSecretKey)201 pub(crate) async fn test_member(
202     protocol_version: ProtocolVersion,
203     cipher_suite: CipherSuite,
204     identifier: &[u8],
205 ) -> (KeyPackageGeneration, SignatureSecretKey) {
206     let (signing_identity, signing_key) = get_test_signing_identity(cipher_suite, identifier).await;
207 
208     let key_package_generator = KeyPackageGenerator {
209         protocol_version,
210         cipher_suite_provider: &test_cipher_suite_provider(cipher_suite),
211         signing_identity: &signing_identity,
212         signing_key: &signing_key,
213         identity_provider: &BasicIdentityProvider,
214     };
215 
216     let key_package = key_package_generator
217         .generate(
218             lifetime(),
219             get_test_capabilities(),
220             ExtensionList::default(),
221             ExtensionList::default(),
222         )
223         .await
224         .unwrap();
225 
226     (key_package, signing_key)
227 }
228 
229 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_group_custom( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, extension_types: Vec<ExtensionType>, leaf_extensions: Option<ExtensionList>, commit_options: Option<CommitOptions>, ) -> TestGroup230 pub(crate) async fn test_group_custom(
231     protocol_version: ProtocolVersion,
232     cipher_suite: CipherSuite,
233     extension_types: Vec<ExtensionType>,
234     leaf_extensions: Option<ExtensionList>,
235     commit_options: Option<CommitOptions>,
236 ) -> TestGroup {
237     let leaf_extensions = leaf_extensions.unwrap_or_default();
238     let commit_options = commit_options.unwrap_or_default();
239 
240     let (signing_identity, secret_key) = get_test_signing_identity(cipher_suite, b"member").await;
241 
242     let group = TestClientBuilder::new_for_test()
243         .leaf_node_extensions(leaf_extensions)
244         .mls_rules(DefaultMlsRules::default().with_commit_options(commit_options))
245         .extension_types(extension_types)
246         .protocol_versions(ProtocolVersion::all())
247         .used_protocol_version(protocol_version)
248         .signing_identity(signing_identity.clone(), secret_key, cipher_suite)
249         .build()
250         .create_group_with_id(TEST_GROUP.to_vec(), group_extensions())
251         .await
252         .unwrap();
253 
254     TestGroup { group }
255 }
256 
257 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_group( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, ) -> TestGroup258 pub(crate) async fn test_group(
259     protocol_version: ProtocolVersion,
260     cipher_suite: CipherSuite,
261 ) -> TestGroup {
262     test_group_custom(
263         protocol_version,
264         cipher_suite,
265         Default::default(),
266         None,
267         None,
268     )
269     .await
270 }
271 
272 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_group_custom_config<F>( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, custom: F, ) -> TestGroup where F: FnOnce(TestClientBuilder) -> TestClientBuilder,273 pub(crate) async fn test_group_custom_config<F>(
274     protocol_version: ProtocolVersion,
275     cipher_suite: CipherSuite,
276     custom: F,
277 ) -> TestGroup
278 where
279     F: FnOnce(TestClientBuilder) -> TestClientBuilder,
280 {
281     let (signing_identity, secret_key) = get_test_signing_identity(cipher_suite, b"member").await;
282 
283     let client_builder = TestClientBuilder::new_for_test().used_protocol_version(protocol_version);
284 
285     let group = custom(client_builder)
286         .signing_identity(signing_identity.clone(), secret_key, cipher_suite)
287         .build()
288         .create_group_with_id(TEST_GROUP.to_vec(), group_extensions())
289         .await
290         .unwrap();
291 
292     TestGroup { group }
293 }
294 
295 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_n_member_group( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, num_members: usize, ) -> Vec<TestGroup>296 pub(crate) async fn test_n_member_group(
297     protocol_version: ProtocolVersion,
298     cipher_suite: CipherSuite,
299     num_members: usize,
300 ) -> Vec<TestGroup> {
301     let group = test_group(protocol_version, cipher_suite).await;
302 
303     let mut groups = vec![group];
304 
305     for i in 1..num_members {
306         let (new_group, commit) = groups.get_mut(0).unwrap().join(&format!("name {i}")).await;
307         process_commit(&mut groups, commit, 0).await;
308         groups.push(new_group);
309     }
310 
311     groups
312 }
313 
314 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
process_commit(groups: &mut [TestGroup], commit: MlsMessage, excluded: u32)315 pub(crate) async fn process_commit(groups: &mut [TestGroup], commit: MlsMessage, excluded: u32) {
316     for g in groups
317         .iter_mut()
318         .filter(|g| g.group.current_member_index() != excluded)
319     {
320         g.process_message(commit.clone()).await.unwrap();
321     }
322 }
323 
get_test_25519_key(key_byte: u8) -> HpkePublicKey324 pub(crate) fn get_test_25519_key(key_byte: u8) -> HpkePublicKey {
325     vec![key_byte; 32].into()
326 }
327 
328 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_test_groups_with_features( n: usize, extensions: ExtensionList, leaf_extensions: ExtensionList, ) -> Vec<Group<TestClientConfig>>329 pub(crate) async fn get_test_groups_with_features(
330     n: usize,
331     extensions: ExtensionList,
332     leaf_extensions: ExtensionList,
333 ) -> Vec<Group<TestClientConfig>> {
334     let mut clients = Vec::new();
335 
336     for i in 0..n {
337         let (identity, secret_key) =
338             get_test_signing_identity(TEST_CIPHER_SUITE, format!("member{i}").as_bytes()).await;
339 
340         clients.push(
341             TestClientBuilder::new_for_test()
342                 .extension_type(999.into())
343                 .leaf_node_extensions(leaf_extensions.clone())
344                 .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
345                 .build(),
346         );
347     }
348 
349     let group = clients[0]
350         .create_group_with_id(b"TEST GROUP".to_vec(), extensions)
351         .await
352         .unwrap();
353 
354     let mut groups = vec![group];
355 
356     for client in clients.iter().skip(1) {
357         let key_package = client.generate_key_package_message().await.unwrap();
358 
359         let commit_output = groups[0]
360             .commit_builder()
361             .add_member(key_package)
362             .unwrap()
363             .build()
364             .await
365             .unwrap();
366 
367         groups[0].apply_pending_commit().await.unwrap();
368 
369         for group in groups.iter_mut().skip(1) {
370             group
371                 .process_incoming_message(commit_output.commit_message.clone())
372                 .await
373                 .unwrap();
374         }
375 
376         groups.push(
377             client
378                 .join_group(None, &commit_output.welcome_messages[0])
379                 .await
380                 .unwrap()
381                 .0,
382         );
383     }
384 
385     groups
386 }
387 
random_bytes(count: usize) -> Vec<u8>388 pub fn random_bytes(count: usize) -> Vec<u8> {
389     let mut buf = vec![0; count];
390     rand::thread_rng().fill_bytes(&mut buf);
391     buf
392 }
393 
/// Test wrapper around a [`Group`] whose `MessageProcessor` implementation
/// records the outcome of commit processing in the public fields below instead
/// of updating the inner group's key schedule.
pub(crate) struct GroupWithoutKeySchedule {
    // The wrapped group; reachable through `Deref`/`DerefMut`.
    inner: Group<TestClientConfig>,
    /// Secrets passed to the most recent `update_key_schedule` call, if any.
    pub secrets: Option<(TreeKemPrivate, PathSecret)>,
    /// Provisional state passed to the most recent `update_key_schedule` call, if any.
    pub provisional_public_state: Option<ProvisionalState>,
}
399 
400 impl Deref for GroupWithoutKeySchedule {
401     type Target = Group<TestClientConfig>;
402 
403     #[cfg_attr(coverage_nightly, coverage(off))]
deref(&self) -> &Self::Target404     fn deref(&self) -> &Self::Target {
405         &self.inner
406     }
407 }
408 
409 impl DerefMut for GroupWithoutKeySchedule {
410     #[cfg_attr(coverage_nightly, coverage(off))]
deref_mut(&mut self) -> &mut Self::Target411     fn deref_mut(&mut self) -> &mut Self::Target {
412         &mut self.inner
413     }
414 }
415 
#[cfg(feature = "rfc_compliant")]
impl GroupWithoutKeySchedule {
    /// Wraps a fresh single-member group created by `test_group` for cipher suite `cs`.
    ///
    /// No secrets or provisional state have been captured yet.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn new(cs: CipherSuite) -> Self {
        let TestGroup { group: inner } = test_group(TEST_PROTOCOL_VERSION, cs).await;

        Self {
            inner,
            secrets: None,
            provisional_public_state: None,
        }
    }
}
427 
// Forwards every `MessageProcessor` operation to the wrapped group, except
// `update_key_schedule`. That method only stores its inputs on the wrapper, so
// tests can inspect what commit processing produced without the inner group
// advancing its key schedule.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
#[cfg_attr(
    all(not(target_arch = "wasm32"), mls_build_async),
    maybe_async::must_be_async
)]
impl MessageProcessor for GroupWithoutKeySchedule {
    // All associated types are borrowed from the wrapped group's implementation.
    type CipherSuiteProvider = <Group<TestClientConfig> as MessageProcessor>::CipherSuiteProvider;
    type OutputType = <Group<TestClientConfig> as MessageProcessor>::OutputType;
    type PreSharedKeyStorage = <Group<TestClientConfig> as MessageProcessor>::PreSharedKeyStorage;
    type IdentityProvider = <Group<TestClientConfig> as MessageProcessor>::IdentityProvider;
    type MlsRules = <Group<TestClientConfig> as MessageProcessor>::MlsRules;

    // Delegates to the inner group.
    fn group_state(&self) -> &GroupState {
        self.inner.group_state()
    }

    // Delegates to the inner group.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn group_state_mut(&mut self) -> &mut GroupState {
        self.inner.group_state_mut()
    }

    // Delegates to the inner group.
    fn mls_rules(&self) -> Self::MlsRules {
        self.inner.mls_rules()
    }

    // Delegates to the inner group.
    fn identity_provider(&self) -> Self::IdentityProvider {
        self.inner.identity_provider()
    }

    // Delegates to the inner group.
    fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider {
        self.inner.cipher_suite_provider()
    }

    // Delegates to the inner group.
    fn psk_storage(&self) -> Self::PreSharedKeyStorage {
        self.inner.psk_storage()
    }

    // Delegates to the inner group.
    fn can_continue_processing(&self, provisional_state: &ProvisionalState) -> bool {
        self.inner.can_continue_processing(provisional_state)
    }

    // Delegates to the inner group.
    #[cfg(feature = "private_message")]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn min_epoch_available(&self) -> Option<u64> {
        self.inner.min_epoch_available()
    }

    // Delegates to the inner group; the resulting secrets (if any) are returned to
    // the caller rather than stored here.
    async fn apply_update_path(
        &mut self,
        sender: LeafIndex,
        update_path: &ValidatedUpdatePath,
        provisional_state: &mut ProvisionalState,
    ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError> {
        self.inner
            .apply_update_path(sender, update_path, provisional_state)
            .await
    }

    // Delegates to the inner group.
    #[cfg(feature = "private_message")]
    #[cfg_attr(coverage_nightly, coverage(off))]
    async fn process_ciphertext(
        &mut self,
        cipher_text: &PrivateMessage,
    ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
        self.inner.process_ciphertext(cipher_text).await
    }

    // Delegates to the inner group.
    #[cfg_attr(coverage_nightly, coverage(off))]
    async fn verify_plaintext_authentication(
        &self,
        message: PublicMessage,
    ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
        self.inner.verify_plaintext_authentication(message).await
    }

    // Intentionally does NOT touch the inner group: it records the provisional
    // state and secrets on the wrapper, ignores the interim transcript hash and
    // confirmation tag, and always succeeds.
    async fn update_key_schedule(
        &mut self,
        secrets: Option<(TreeKemPrivate, PathSecret)>,
        _interim_transcript_hash: InterimTranscriptHash,
        _confirmation_tag: &ConfirmationTag,
        provisional_public_state: ProvisionalState,
    ) -> Result<(), MlsError> {
        self.provisional_public_state = Some(provisional_public_state);
        self.secrets = secrets;
        Ok(())
    }

    // Fully qualified call so the trait's `self_index` on the inner group is used.
    #[cfg(feature = "private_message")]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn self_index(&self) -> Option<LeafIndex> {
        <Group<TestClientConfig> as MessageProcessor>::self_index(&self.inner)
    }
}
522