1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 #[cfg(all(feature = "benchmark_util", not(mls_build_async)))]
6 pub mod benchmarks;
7 
8 #[cfg(all(feature = "fuzz_util", not(mls_build_async)))]
9 pub mod fuzz_tests;
10 
11 use mls_rs_core::{
12     crypto::{CipherSuite, CipherSuiteProvider, CryptoProvider},
13     identity::{BasicCredential, Credential, SigningIdentity},
14     protocol_version::ProtocolVersion,
15     psk::ExternalPskId,
16 };
17 
18 use crate::{
19     client_builder::{ClientBuilder, MlsConfig},
20     identity::basic::BasicIdentityProvider,
21     mls_rules::{CommitOptions, DefaultMlsRules},
22     tree_kem::Lifetime,
23     Client, Group, MlsMessage,
24 };
25 
26 #[cfg(feature = "private_message")]
27 use crate::group::{mls_rules::EncryptionOptions, padding::PaddingMode};
28 
29 use alloc::{vec, vec::Vec};
30 
31 #[cfg_attr(coverage_nightly, coverage(off))]
get_test_basic_credential(identity: Vec<u8>) -> Credential32 pub fn get_test_basic_credential(identity: Vec<u8>) -> Credential {
33     BasicCredential::new(identity).into_credential()
34 }
35 
/// Raw identifier bytes of the external pre-shared key that
/// `generate_basic_client` registers on every test client.
pub const TEST_EXT_PSK_ID: &[u8] = b"external psk";
37 
/// Returns a freshly allocated copy of the fixed external PSK secret used by
/// the test helpers.
#[cfg_attr(coverage_nightly, coverage(off))]
pub fn make_test_ext_psk() -> Vec<u8> {
    const PSK_SECRET: &[u8] = b"secret psk key";
    Vec::from(PSK_SECRET)
}
42 
is_edwards(cs: u16) -> bool43 pub fn is_edwards(cs: u16) -> bool {
44     [
45         CipherSuite::CURVE25519_AES128,
46         CipherSuite::CURVE25519_CHACHA,
47         CipherSuite::CURVE448_AES256,
48         CipherSuite::CURVE448_CHACHA,
49     ]
50     .contains(&cs.into())
51 }
52 
53 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
54 #[cfg_attr(coverage_nightly, coverage(off))]
generate_basic_client<C: CryptoProvider + Clone>( cipher_suite: CipherSuite, protocol_version: ProtocolVersion, id: usize, commit_options: Option<CommitOptions>, #[cfg(feature = "private_message")] encrypt_controls: bool, #[cfg(not(feature = "private_message"))] _encrypt_controls: bool, crypto: &C, lifetime: Option<Lifetime>, ) -> Client<impl MlsConfig>55 pub async fn generate_basic_client<C: CryptoProvider + Clone>(
56     cipher_suite: CipherSuite,
57     protocol_version: ProtocolVersion,
58     id: usize,
59     commit_options: Option<CommitOptions>,
60     #[cfg(feature = "private_message")] encrypt_controls: bool,
61     #[cfg(not(feature = "private_message"))] _encrypt_controls: bool,
62     crypto: &C,
63     lifetime: Option<Lifetime>,
64 ) -> Client<impl MlsConfig> {
65     let cs = crypto.cipher_suite_provider(cipher_suite).unwrap();
66 
67     let (secret_key, public_key) = cs.signature_key_generate().await.unwrap();
68     let credential = get_test_basic_credential(alloc::format!("{id}").into_bytes());
69 
70     let identity = SigningIdentity::new(credential, public_key);
71 
72     let mls_rules =
73         DefaultMlsRules::default().with_commit_options(commit_options.unwrap_or_default());
74 
75     #[cfg(feature = "private_message")]
76     let mls_rules = if encrypt_controls {
77         mls_rules.with_encryption_options(EncryptionOptions::new(true, PaddingMode::None))
78     } else {
79         mls_rules
80     };
81 
82     let mut builder = ClientBuilder::new()
83         .crypto_provider(crypto.clone())
84         .identity_provider(BasicIdentityProvider::new())
85         .mls_rules(mls_rules)
86         .psk(
87             ExternalPskId::new(TEST_EXT_PSK_ID.to_vec()),
88             make_test_ext_psk().into(),
89         )
90         .used_protocol_version(protocol_version)
91         .signing_identity(identity, secret_key, cipher_suite);
92 
93     if let Some(lifetime) = lifetime {
94         builder = builder
95             .key_package_lifetime(lifetime.not_after - lifetime.not_before)
96             .key_package_not_before(lifetime.not_before);
97     }
98 
99     builder.build()
100 }
101 
102 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
103 #[cfg_attr(coverage_nightly, coverage(off))]
get_test_groups<C: CryptoProvider + Clone>( version: ProtocolVersion, cipher_suite: CipherSuite, num_participants: usize, commit_options: Option<CommitOptions>, encrypt_controls: bool, crypto: &C, ) -> Vec<Group<impl MlsConfig>>104 pub async fn get_test_groups<C: CryptoProvider + Clone>(
105     version: ProtocolVersion,
106     cipher_suite: CipherSuite,
107     num_participants: usize,
108     commit_options: Option<CommitOptions>,
109     encrypt_controls: bool,
110     crypto: &C,
111 ) -> Vec<Group<impl MlsConfig>> {
112     // Create the group with Alice as the group initiator
113     let creator = generate_basic_client(
114         cipher_suite,
115         version,
116         0,
117         commit_options,
118         encrypt_controls,
119         crypto,
120         None,
121     )
122     .await;
123 
124     let mut creator_group = creator.create_group(Default::default()).await.unwrap();
125 
126     let mut receiver_clients = Vec::new();
127     let mut commit_builder = creator_group.commit_builder();
128 
129     for i in 1..num_participants {
130         let client = generate_basic_client(
131             cipher_suite,
132             version,
133             i,
134             commit_options,
135             encrypt_controls,
136             crypto,
137             None,
138         )
139         .await;
140         let kp = client.generate_key_package_message().await.unwrap();
141 
142         receiver_clients.push(client);
143         commit_builder = commit_builder.add_member(kp.clone()).unwrap();
144     }
145 
146     let welcome = commit_builder.build().await.unwrap().welcome_messages;
147 
148     creator_group.apply_pending_commit().await.unwrap();
149 
150     let tree_data = creator_group.export_tree().into_owned();
151 
152     let mut groups = vec![creator_group];
153 
154     for client in &receiver_clients {
155         let (test_client, _info) = client
156             .join_group(Some(tree_data.clone()), &welcome[0])
157             .await
158             .unwrap();
159 
160         groups.push(test_client);
161     }
162 
163     groups
164 }
165 
166 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
167 #[cfg_attr(coverage_nightly, coverage(off))]
all_process_message<C: MlsConfig>( groups: &mut [Group<C>], message: &MlsMessage, sender: usize, is_commit: bool, )168 pub async fn all_process_message<C: MlsConfig>(
169     groups: &mut [Group<C>],
170     message: &MlsMessage,
171     sender: usize,
172     is_commit: bool,
173 ) {
174     for group in groups {
175         if sender != group.current_member_index() as usize {
176             group
177                 .process_incoming_message(message.clone())
178                 .await
179                 .unwrap();
180         } else if is_commit {
181             group.apply_pending_commit().await.unwrap();
182         }
183     }
184 }
185