1 use mls_rs_codec::MlsEncode;
2 use mls_rs_core::protocol_version::ProtocolVersion;
3
4 use crate::{
5 cipher_suite::CipherSuite,
6 client_builder::{BaseConfig, MlsConfig, WithCryptoProvider, WithIdentityProvider},
7 group::{framing::MlsMessage, Group},
8 identity::basic::BasicIdentityProvider,
9 test_utils::{generate_basic_client, get_test_groups},
10 };
11
12 pub use mls_rs_crypto_openssl::OpensslCryptoProvider as MlsCryptoProvider;
13
/// Client configuration used by these helpers: a [`BaseConfig`] wrapped with
/// [`MlsCryptoProvider`] as the crypto provider and [`BasicIdentityProvider`]
/// as the identity provider.
pub type TestClientConfig =
    WithIdentityProvider<BasicIdentityProvider, WithCryptoProvider<MlsCryptoProvider, BaseConfig>>;
16
/// Loads the MLS-encoded fixture `test_data/<name>.mls` (relative to
/// `CARGO_MANIFEST_DIR`) and decodes it with `mls_rs_codec::MlsDecode`.
///
/// * `$name` - identifier used as the fixture's file stem.
/// * `$generate` - expression producing the value to encode. It is evaluated
///   only on non-wasm `std` builds, and only when the fixture file does not
///   exist yet; the encoded result is then written to disk.
/// * `$to_json` - accepted for call-site compatibility but not used by this
///   macro. The two-argument form passes `to_vec_pretty`.
///
/// The decoded type is inferred from the call site. Every failure (I/O,
/// encoding, decoding) panics via `unwrap`.
macro_rules! load_test_case_mls {
    ($name:ident, $generate:expr) => {
        load_test_case_mls!($name, $generate, to_vec_pretty)
    };
    ($name:ident, $generate:expr, $to_json:ident) => {{
        #[cfg(any(target_arch = "wasm32", not(feature = "std")))]
        {
            // Do not remove `async`! (The goal of this line is to remove warnings
            // about `$generate` not being used. Actually calling it will make tests fail.)
            let _ = async { $generate };

            // `include_bytes!` already yields `&'static [u8; N]`. Slice it with
            // `[..]` so the reader is a `&mut &[u8]`, matching what the `std`
            // branch below hands to `mls_decode`.
            mls_rs_codec::MlsDecode::mls_decode(&mut &include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/test_data/",
                stringify!($name),
                ".mls"
            ))[..])
            .unwrap()
        }

        #[cfg(all(not(target_arch = "wasm32"), feature = "std"))]
        {
            let path = concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/test_data/",
                stringify!($name),
                ".mls"
            );

            // Create the fixture on first use; later runs reuse the stored bytes.
            if !std::path::Path::new(path).exists() {
                std::fs::write(path, $generate.mls_encode_to_vec().unwrap()).unwrap();
            }

            mls_rs_codec::MlsDecode::mls_decode(&mut std::fs::read(path).unwrap().as_slice())
                .unwrap()
        }
    }};
}
55
56 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
generate_test_cases(cs: CipherSuite) -> Vec<MlsMessage>57 async fn generate_test_cases(cs: CipherSuite) -> Vec<MlsMessage> {
58 let mut cases = Vec::new();
59
60 for size in [16, 64, 128] {
61 let group = get_test_groups(
62 ProtocolVersion::MLS_10,
63 cs,
64 size,
65 None,
66 false,
67 &MlsCryptoProvider::new(),
68 )
69 .await
70 .pop()
71 .unwrap();
72
73 let group_info = group
74 .group_info_message_allowing_ext_commit(true)
75 .await
76 .unwrap();
77
78 cases.push(group_info)
79 }
80
81 cases
82 }
83
84 #[derive(Clone)]
85 pub struct GroupStates<C: MlsConfig> {
86 pub sender: Group<C>,
87 pub receiver: Group<C>,
88 }
89
/// Async-build variant: loads the stored group-info fixtures for `cs` and
/// turns each one into a sender/receiver [`GroupStates`] pair.
///
/// The fixture `group_state` holds a `Vec<MlsMessage>` (that is what
/// `generate_test_cases` returns and what gets encoded), so it is decoded as
/// such and every message is joined individually. `join_group` is an
/// `async fn` in this configuration, so each future is driven to completion
/// with `block_on`.
///
/// Panics if the fixture cannot be loaded or any join step fails.
// NOTE(review): `block_on` is not imported in the visible part of this file;
// it must be brought into scope for async builds - confirm which executor is
// intended.
#[cfg(mls_build_async)]
pub fn load_group_states(cs: CipherSuite) -> Vec<GroupStates<impl MlsConfig>> {
    let group_infos: Vec<MlsMessage> =
        load_test_case_mls!(group_state, block_on(generate_test_cases(cs)), to_vec);

    group_infos
        .into_iter()
        .map(|info| block_on(join_group(cs, info)))
        .collect()
}
95
96 #[cfg(not(mls_build_async))]
load_group_states(cs: CipherSuite) -> Vec<GroupStates<impl MlsConfig>>97 pub fn load_group_states(cs: CipherSuite) -> Vec<GroupStates<impl MlsConfig>> {
98 let group_infos: Vec<MlsMessage> =
99 load_test_case_mls!(group_state, generate_test_cases(cs), to_vec);
100
101 group_infos
102 .into_iter()
103 .map(|info| join_group(cs, info))
104 .collect()
105 }
106
107 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
join_group(cs: CipherSuite, group_info: MlsMessage) -> GroupStates<impl MlsConfig>108 pub async fn join_group(cs: CipherSuite, group_info: MlsMessage) -> GroupStates<impl MlsConfig> {
109 let client = generate_basic_client(
110 cs,
111 ProtocolVersion::MLS_10,
112 99999999999,
113 None,
114 false,
115 &MlsCryptoProvider::new(),
116 None,
117 );
118
119 let mut sender = client.commit_external(group_info).await.unwrap().0;
120
121 let client = generate_basic_client(
122 cs,
123 ProtocolVersion::MLS_10,
124 99999999998,
125 None,
126 false,
127 &MlsCryptoProvider::new(),
128 None,
129 );
130
131 let group_info = sender
132 .group_info_message_allowing_ext_commit(true)
133 .await
134 .unwrap();
135
136 let (receiver, commit) = client.commit_external(group_info).await.unwrap();
137 sender.process_incoming_message(commit).await.unwrap();
138
139 GroupStates { sender, receiver }
140 }
141