1 use mls_rs_codec::MlsEncode;
2 use mls_rs_core::protocol_version::ProtocolVersion;
3 
4 use crate::{
5     cipher_suite::CipherSuite,
6     client_builder::{BaseConfig, MlsConfig, WithCryptoProvider, WithIdentityProvider},
7     group::{framing::MlsMessage, Group},
8     identity::basic::BasicIdentityProvider,
9     test_utils::{generate_basic_client, get_test_groups},
10 };
11 
12 pub use mls_rs_crypto_openssl::OpensslCryptoProvider as MlsCryptoProvider;
13 
14 pub type TestClientConfig =
15     WithIdentityProvider<BasicIdentityProvider, WithCryptoProvider<MlsCryptoProvider, BaseConfig>>;
16 
/// Loads an MLS-encoded test fixture named `$name` from
/// `$CARGO_MANIFEST_DIR/test_data/<$name>.mls` and decodes it with `MlsDecode`.
///
/// The decoded type is whatever the call site infers, so callers should annotate it
/// and make it match the type `$generate` produces.
///
/// * On `wasm32` or without the `std` feature, the file is embedded at compile time
///   with `include_bytes!` and `$generate` is never evaluated.
/// * With `std` on non-wasm targets, if the file does not exist, `$generate` is
///   evaluated, encoded with `mls_encode_to_vec`, and written to that path. The file
///   is then read back and decoded.
///
/// The optional third argument (`$to_json`, default `to_vec_pretty`) is accepted but
/// not referenced by the macro body.
///
/// Panics (via `unwrap`) if encoding, writing, reading, or decoding fails.
/// NOTE(review): `std::fs::write` does not create missing directories, so `test_data/`
/// presumably must already exist.
macro_rules! load_test_case_mls {
    ($name:ident, $generate:expr) => {
        load_test_case_mls!($name, $generate, to_vec_pretty)
    };
    ($name:ident, $generate:expr, $to_json:ident) => {{
        #[cfg(any(target_arch = "wasm32", not(feature = "std")))]
        {
            // Do not remove `async`! (The goal of this line is to remove warnings
            // about `$generate` not being used. Actually calling it will make tests fail.)
            // The async block is created and immediately dropped without being polled,
            // so `$generate` is type-checked but never executed here.
            let _ = async { $generate };

            // Fixture bytes are baked into the binary at compile time.
            mls_rs_codec::MlsDecode::mls_decode(&mut &include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/test_data/",
                stringify!($name),
                ".mls"
            )))
            .unwrap()
        }

        #[cfg(all(not(target_arch = "wasm32"), feature = "std"))]
        {
            let path = concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/test_data/",
                stringify!($name),
                ".mls"
            );

            // Generate and cache the fixture only on first use; later runs reuse the file.
            if !std::path::Path::new(path).exists() {
                std::fs::write(path, $generate.mls_encode_to_vec().unwrap()).unwrap();
            }

            mls_rs_codec::MlsDecode::mls_decode(&mut std::fs::read(path).unwrap().as_slice())
                .unwrap()
        }
    }};
}
55 
56 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
generate_test_cases(cs: CipherSuite) -> Vec<MlsMessage>57 async fn generate_test_cases(cs: CipherSuite) -> Vec<MlsMessage> {
58     let mut cases = Vec::new();
59 
60     for size in [16, 64, 128] {
61         let group = get_test_groups(
62             ProtocolVersion::MLS_10,
63             cs,
64             size,
65             None,
66             false,
67             &MlsCryptoProvider::new(),
68         )
69         .await
70         .pop()
71         .unwrap();
72 
73         let group_info = group
74             .group_info_message_allowing_ext_commit(true)
75             .await
76             .unwrap();
77 
78         cases.push(group_info)
79     }
80 
81     cases
82 }
83 
84 #[derive(Clone)]
85 pub struct GroupStates<C: MlsConfig> {
86     pub sender: Group<C>,
87     pub receiver: Group<C>,
88 }
89 
/// Loads the cached group info messages and joins each one with [`join_group`],
/// returning one [`GroupStates`] per message.
///
/// On first use, where the platform allows it, the messages are generated and cached.
/// This is the async-build variant: the async `generate_test_cases` and `join_group` are
/// driven to completion with `block_on`.
///
/// NOTE(review): the cache file name is derived only from `group_state`, not from `cs`.
/// A cache written for one cipher suite is therefore reused for any other; confirm that
/// callers use a single suite.
///
/// # Panics
///
/// Panics if loading or decoding the cached messages fails, or if any join step fails.
#[cfg(mls_build_async)]
pub fn load_group_states(cs: CipherSuite) -> Vec<GroupStates<impl MlsConfig>> {
    // Must decode as the same type `generate_test_cases` encodes: a list of messages.
    let group_infos: Vec<MlsMessage> =
        load_test_case_mls!(group_state, block_on(generate_test_cases(cs)), to_vec);

    // `join_group` is async in this build, so each join must be driven to completion.
    group_infos
        .into_iter()
        .map(|info| block_on(join_group(cs, info)))
        .collect()
}
95 
96 #[cfg(not(mls_build_async))]
load_group_states(cs: CipherSuite) -> Vec<GroupStates<impl MlsConfig>>97 pub fn load_group_states(cs: CipherSuite) -> Vec<GroupStates<impl MlsConfig>> {
98     let group_infos: Vec<MlsMessage> =
99         load_test_case_mls!(group_state, generate_test_cases(cs), to_vec);
100 
101     group_infos
102         .into_iter()
103         .map(|info| join_group(cs, info))
104         .collect()
105 }
106 
107 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
join_group(cs: CipherSuite, group_info: MlsMessage) -> GroupStates<impl MlsConfig>108 pub async fn join_group(cs: CipherSuite, group_info: MlsMessage) -> GroupStates<impl MlsConfig> {
109     let client = generate_basic_client(
110         cs,
111         ProtocolVersion::MLS_10,
112         99999999999,
113         None,
114         false,
115         &MlsCryptoProvider::new(),
116         None,
117     );
118 
119     let mut sender = client.commit_external(group_info).await.unwrap().0;
120 
121     let client = generate_basic_client(
122         cs,
123         ProtocolVersion::MLS_10,
124         99999999998,
125         None,
126         false,
127         &MlsCryptoProvider::new(),
128         None,
129     );
130 
131     let group_info = sender
132         .group_info_message_allowing_ext_commit(true)
133         .await
134         .unwrap();
135 
136     let (receiver, commit) = client.commit_external(group_info).await.unwrap();
137     sender.process_incoming_message(commit).await.unwrap();
138 
139     GroupStates { sender, receiver }
140 }
141