1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
/// Compiled only when the `benchmark_util` feature is enabled and the
/// `mls_build_async` cfg is not set (i.e. synchronous builds only).
#[cfg(all(feature = "benchmark_util", not(mls_build_async)))]
pub mod benchmarks;

/// Compiled only when the `fuzz_util` feature is enabled and the
/// `mls_build_async` cfg is not set (i.e. synchronous builds only).
#[cfg(all(feature = "fuzz_util", not(mls_build_async)))]
pub mod fuzz_tests;
10
11 use mls_rs_core::{
12 crypto::{CipherSuite, CipherSuiteProvider, CryptoProvider},
13 identity::{BasicCredential, Credential, SigningIdentity},
14 protocol_version::ProtocolVersion,
15 psk::ExternalPskId,
16 };
17
18 use crate::{
19 client_builder::{ClientBuilder, MlsConfig},
20 identity::basic::BasicIdentityProvider,
21 mls_rules::{CommitOptions, DefaultMlsRules},
22 tree_kem::Lifetime,
23 Client, Group, MlsMessage,
24 };
25
26 #[cfg(feature = "private_message")]
27 use crate::group::{mls_rules::EncryptionOptions, padding::PaddingMode};
28
29 use alloc::{vec, vec::Vec};
30
31 #[cfg_attr(coverage_nightly, coverage(off))]
get_test_basic_credential(identity: Vec<u8>) -> Credential32 pub fn get_test_basic_credential(identity: Vec<u8>) -> Credential {
33 BasicCredential::new(identity).into_credential()
34 }
35
/// Identifier of the external PSK that `generate_basic_client` registers on
/// every test client; the matching secret value is `make_test_ext_psk()`.
pub const TEST_EXT_PSK_ID: &[u8] = b"external psk";
37
/// Returns a fresh copy of the fixed secret bytes used as the value of the
/// external test PSK (the one identified by `TEST_EXT_PSK_ID`).
#[cfg_attr(coverage_nightly, coverage(off))]
pub fn make_test_ext_psk() -> Vec<u8> {
    const SECRET: &[u8] = b"secret psk key";
    Vec::from(SECRET)
}
42
is_edwards(cs: u16) -> bool43 pub fn is_edwards(cs: u16) -> bool {
44 [
45 CipherSuite::CURVE25519_AES128,
46 CipherSuite::CURVE25519_CHACHA,
47 CipherSuite::CURVE448_AES256,
48 CipherSuite::CURVE448_CHACHA,
49 ]
50 .contains(&cs.into())
51 }
52
53 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
54 #[cfg_attr(coverage_nightly, coverage(off))]
generate_basic_client<C: CryptoProvider + Clone>( cipher_suite: CipherSuite, protocol_version: ProtocolVersion, id: usize, commit_options: Option<CommitOptions>, #[cfg(feature = "private_message")] encrypt_controls: bool, #[cfg(not(feature = "private_message"))] _encrypt_controls: bool, crypto: &C, lifetime: Option<Lifetime>, ) -> Client<impl MlsConfig>55 pub async fn generate_basic_client<C: CryptoProvider + Clone>(
56 cipher_suite: CipherSuite,
57 protocol_version: ProtocolVersion,
58 id: usize,
59 commit_options: Option<CommitOptions>,
60 #[cfg(feature = "private_message")] encrypt_controls: bool,
61 #[cfg(not(feature = "private_message"))] _encrypt_controls: bool,
62 crypto: &C,
63 lifetime: Option<Lifetime>,
64 ) -> Client<impl MlsConfig> {
65 let cs = crypto.cipher_suite_provider(cipher_suite).unwrap();
66
67 let (secret_key, public_key) = cs.signature_key_generate().await.unwrap();
68 let credential = get_test_basic_credential(alloc::format!("{id}").into_bytes());
69
70 let identity = SigningIdentity::new(credential, public_key);
71
72 let mls_rules =
73 DefaultMlsRules::default().with_commit_options(commit_options.unwrap_or_default());
74
75 #[cfg(feature = "private_message")]
76 let mls_rules = if encrypt_controls {
77 mls_rules.with_encryption_options(EncryptionOptions::new(true, PaddingMode::None))
78 } else {
79 mls_rules
80 };
81
82 let mut builder = ClientBuilder::new()
83 .crypto_provider(crypto.clone())
84 .identity_provider(BasicIdentityProvider::new())
85 .mls_rules(mls_rules)
86 .psk(
87 ExternalPskId::new(TEST_EXT_PSK_ID.to_vec()),
88 make_test_ext_psk().into(),
89 )
90 .used_protocol_version(protocol_version)
91 .signing_identity(identity, secret_key, cipher_suite);
92
93 if let Some(lifetime) = lifetime {
94 builder = builder
95 .key_package_lifetime(lifetime.not_after - lifetime.not_before)
96 .key_package_not_before(lifetime.not_before);
97 }
98
99 builder.build()
100 }
101
102 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
103 #[cfg_attr(coverage_nightly, coverage(off))]
get_test_groups<C: CryptoProvider + Clone>( version: ProtocolVersion, cipher_suite: CipherSuite, num_participants: usize, commit_options: Option<CommitOptions>, encrypt_controls: bool, crypto: &C, ) -> Vec<Group<impl MlsConfig>>104 pub async fn get_test_groups<C: CryptoProvider + Clone>(
105 version: ProtocolVersion,
106 cipher_suite: CipherSuite,
107 num_participants: usize,
108 commit_options: Option<CommitOptions>,
109 encrypt_controls: bool,
110 crypto: &C,
111 ) -> Vec<Group<impl MlsConfig>> {
112 // Create the group with Alice as the group initiator
113 let creator = generate_basic_client(
114 cipher_suite,
115 version,
116 0,
117 commit_options,
118 encrypt_controls,
119 crypto,
120 None,
121 )
122 .await;
123
124 let mut creator_group = creator.create_group(Default::default()).await.unwrap();
125
126 let mut receiver_clients = Vec::new();
127 let mut commit_builder = creator_group.commit_builder();
128
129 for i in 1..num_participants {
130 let client = generate_basic_client(
131 cipher_suite,
132 version,
133 i,
134 commit_options,
135 encrypt_controls,
136 crypto,
137 None,
138 )
139 .await;
140 let kp = client.generate_key_package_message().await.unwrap();
141
142 receiver_clients.push(client);
143 commit_builder = commit_builder.add_member(kp.clone()).unwrap();
144 }
145
146 let welcome = commit_builder.build().await.unwrap().welcome_messages;
147
148 creator_group.apply_pending_commit().await.unwrap();
149
150 let tree_data = creator_group.export_tree().into_owned();
151
152 let mut groups = vec![creator_group];
153
154 for client in &receiver_clients {
155 let (test_client, _info) = client
156 .join_group(Some(tree_data.clone()), &welcome[0])
157 .await
158 .unwrap();
159
160 groups.push(test_client);
161 }
162
163 groups
164 }
165
166 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
167 #[cfg_attr(coverage_nightly, coverage(off))]
all_process_message<C: MlsConfig>( groups: &mut [Group<C>], message: &MlsMessage, sender: usize, is_commit: bool, )168 pub async fn all_process_message<C: MlsConfig>(
169 groups: &mut [Group<C>],
170 message: &MlsMessage,
171 sender: usize,
172 is_commit: bool,
173 ) {
174 for group in groups {
175 if sender != group.current_member_index() as usize {
176 group
177 .process_incoming_message(message.clone())
178 .await
179 .unwrap();
180 } else if is_commit {
181 group.apply_pending_commit().await.unwrap();
182 }
183 }
184 }
185