1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use self::{
6     message_key::MessageKey,
7     reuse_guard::ReuseGuard,
8     sender_data_key::{SenderData, SenderDataAAD, SenderDataKey},
9 };
10 
11 use super::{
12     epoch::EpochSecrets,
13     framing::{ContentType, FramedContent, Sender, WireFormat},
14     message_signature::AuthenticatedContent,
15     padding::PaddingMode,
16     secret_tree::{KeyType, MessageKeyData},
17     GroupContext,
18 };
19 use crate::{
20     client::MlsError,
21     tree_kem::node::{LeafIndex, NodeIndex},
22 };
23 use mls_rs_codec::MlsEncode;
24 use mls_rs_core::{crypto::CipherSuiteProvider, error::IntoAnyError};
25 use zeroize::Zeroizing;
26 
27 mod message_key;
28 mod reuse_guard;
29 mod sender_data_key;
30 
31 #[cfg(feature = "private_message")]
32 use super::framing::{PrivateContentAAD, PrivateMessage, PrivateMessageContent};
33 
34 #[cfg(test)]
35 pub use sender_data_key::test_utils::*;
36 
37 pub(crate) trait GroupStateProvider {
group_context(&self) -> &GroupContext38     fn group_context(&self) -> &GroupContext;
self_index(&self) -> LeafIndex39     fn self_index(&self) -> LeafIndex;
epoch_secrets_mut(&mut self) -> &mut EpochSecrets40     fn epoch_secrets_mut(&mut self) -> &mut EpochSecrets;
epoch_secrets(&self) -> &EpochSecrets41     fn epoch_secrets(&self) -> &EpochSecrets;
42 }
43 
44 pub(crate) struct CiphertextProcessor<'a, GS, CP>
45 where
46     GS: GroupStateProvider,
47     CP: CipherSuiteProvider,
48 {
49     group_state: &'a mut GS,
50     cipher_suite_provider: CP,
51 }
52 
53 impl<'a, GS, CP> CiphertextProcessor<'a, GS, CP>
54 where
55     GS: GroupStateProvider,
56     CP: CipherSuiteProvider,
57 {
new( group_state: &'a mut GS, cipher_suite_provider: CP, ) -> CiphertextProcessor<'a, GS, CP>58     pub fn new(
59         group_state: &'a mut GS,
60         cipher_suite_provider: CP,
61     ) -> CiphertextProcessor<'a, GS, CP> {
62         Self {
63             group_state,
64             cipher_suite_provider,
65         }
66     }
67 
68     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
next_encryption_key( &mut self, key_type: KeyType, ) -> Result<MessageKeyData, MlsError>69     pub async fn next_encryption_key(
70         &mut self,
71         key_type: KeyType,
72     ) -> Result<MessageKeyData, MlsError> {
73         let self_index = NodeIndex::from(self.group_state.self_index());
74 
75         self.group_state
76             .epoch_secrets_mut()
77             .secret_tree
78             .next_message_key(&self.cipher_suite_provider, self_index, key_type)
79             .await
80     }
81 
82     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
decryption_key( &mut self, sender: LeafIndex, key_type: KeyType, generation: u32, ) -> Result<MessageKeyData, MlsError>83     pub async fn decryption_key(
84         &mut self,
85         sender: LeafIndex,
86         key_type: KeyType,
87         generation: u32,
88     ) -> Result<MessageKeyData, MlsError> {
89         let sender = NodeIndex::from(sender);
90 
91         self.group_state
92             .epoch_secrets_mut()
93             .secret_tree
94             .message_key_generation(&self.cipher_suite_provider, sender, key_type, generation)
95             .await
96     }
97 
98     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
seal( &mut self, auth_content: AuthenticatedContent, padding: PaddingMode, ) -> Result<PrivateMessage, MlsError>99     pub async fn seal(
100         &mut self,
101         auth_content: AuthenticatedContent,
102         padding: PaddingMode,
103     ) -> Result<PrivateMessage, MlsError> {
104         if Sender::Member(*self.group_state.self_index()) != auth_content.content.sender {
105             return Err(MlsError::InvalidSender);
106         }
107 
108         let content_type = ContentType::from(&auth_content.content.content);
109         let authenticated_data = auth_content.content.authenticated_data;
110 
111         // Build a ciphertext content using the plaintext content and signature
112         let private_content = PrivateMessageContent {
113             content: auth_content.content.content,
114             auth: auth_content.auth,
115         };
116 
117         // Build ciphertext aad using the plaintext message
118         let aad = PrivateContentAAD {
119             group_id: auth_content.content.group_id,
120             epoch: auth_content.content.epoch,
121             content_type,
122             authenticated_data: authenticated_data.clone(),
123         };
124 
125         // Generate a 4 byte reuse guard
126         let reuse_guard = ReuseGuard::random(&self.cipher_suite_provider)
127             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
128 
129         // Grab an encryption key from the current epoch's key schedule
130         let key_type = match &content_type {
131             ContentType::Application => KeyType::Application,
132             _ => KeyType::Handshake,
133         };
134 
135         let mut serialized_private_content = private_content.mls_encode_to_vec()?;
136 
137         // Apply padding to private content based on the current padding mode.
138         serialized_private_content.resize(padding.padded_size(serialized_private_content.len()), 0);
139 
140         let serialized_private_content = Zeroizing::new(serialized_private_content);
141 
142         // Encrypt the ciphertext content using the encryption key and a nonce that is
143         // reuse safe by xor the reuse guard with the first 4 bytes
144         let self_index = self.group_state.self_index();
145 
146         let key_data = self.next_encryption_key(key_type).await?;
147         let generation = key_data.generation;
148 
149         let ciphertext = MessageKey::new(key_data)
150             .encrypt(
151                 &self.cipher_suite_provider,
152                 &serialized_private_content,
153                 &aad.mls_encode_to_vec()?,
154                 &reuse_guard,
155             )
156             .await
157             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
158 
159         // Construct an mls sender data struct using the plaintext sender info, the generation
160         // of the key schedule encryption key, and the reuse guard used to encrypt ciphertext
161         let sender_data = SenderData {
162             sender: self_index,
163             generation,
164             reuse_guard,
165         };
166 
167         let sender_data_aad = SenderDataAAD {
168             group_id: self.group_state.group_context().group_id.clone(),
169             epoch: self.group_state.group_context().epoch,
170             content_type,
171         };
172 
173         // Encrypt the sender data with the derived sender_key and sender_nonce from the current
174         // epoch's key schedule
175         let sender_data_key = SenderDataKey::new(
176             &self.group_state.epoch_secrets().sender_data_secret,
177             &ciphertext,
178             &self.cipher_suite_provider,
179         )
180         .await?;
181 
182         let encrypted_sender_data = sender_data_key.seal(&sender_data, &sender_data_aad).await?;
183 
184         Ok(PrivateMessage {
185             group_id: self.group_state.group_context().group_id.clone(),
186             epoch: self.group_state.group_context().epoch,
187             content_type,
188             authenticated_data,
189             encrypted_sender_data,
190             ciphertext,
191         })
192     }
193 
194     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
open( &mut self, ciphertext: &PrivateMessage, ) -> Result<AuthenticatedContent, MlsError>195     pub async fn open(
196         &mut self,
197         ciphertext: &PrivateMessage,
198     ) -> Result<AuthenticatedContent, MlsError> {
199         // Decrypt the sender data with the derived sender_key and sender_nonce from the message
200         // epoch's key schedule
201         let sender_data_aad = SenderDataAAD {
202             group_id: self.group_state.group_context().group_id.clone(),
203             epoch: self.group_state.group_context().epoch,
204             content_type: ciphertext.content_type,
205         };
206 
207         let sender_data_key = SenderDataKey::new(
208             &self.group_state.epoch_secrets().sender_data_secret,
209             &ciphertext.ciphertext,
210             &self.cipher_suite_provider,
211         )
212         .await?;
213 
214         let sender_data = sender_data_key
215             .open(&ciphertext.encrypted_sender_data, &sender_data_aad)
216             .await?;
217 
218         if self.group_state.self_index() == sender_data.sender {
219             return Err(MlsError::CantProcessMessageFromSelf);
220         }
221 
222         // Grab a decryption key from the message epoch's key schedule
223         let key_type = match &ciphertext.content_type {
224             ContentType::Application => KeyType::Application,
225             _ => KeyType::Handshake,
226         };
227 
228         // Decrypt the content of the message using the grabbed key
229         let key = self
230             .decryption_key(sender_data.sender, key_type, sender_data.generation)
231             .await?;
232 
233         let sender = Sender::Member(*sender_data.sender);
234 
235         let decrypted_content = MessageKey::new(key)
236             .decrypt(
237                 &self.cipher_suite_provider,
238                 &ciphertext.ciphertext,
239                 &PrivateContentAAD::from(ciphertext).mls_encode_to_vec()?,
240                 &sender_data.reuse_guard,
241             )
242             .await
243             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
244 
245         let ciphertext_content =
246             PrivateMessageContent::mls_decode(&mut &**decrypted_content, ciphertext.content_type)?;
247 
248         // Build the MLS plaintext object and process it
249         let auth_content = AuthenticatedContent {
250             wire_format: WireFormat::PrivateMessage,
251             content: FramedContent {
252                 group_id: ciphertext.group_id.clone(),
253                 epoch: ciphertext.epoch,
254                 sender,
255                 authenticated_data: ciphertext.authenticated_data.clone(),
256                 content: ciphertext_content.content,
257             },
258             auth: ciphertext_content.auth,
259         };
260 
261         Ok(auth_content)
262     }
263 }
264 
#[cfg(test)]
mod test {
    use crate::{
        cipher_suite::CipherSuite,
        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        crypto::{
            test_utils::{test_cipher_suite_provider, TestCryptoProvider},
            CipherSuiteProvider,
        },
        group::{
            framing::{ApplicationData, Content, Sender, WireFormat},
            message_signature::AuthenticatedContent,
            padding::PaddingMode,
            test_utils::{random_bytes, test_group, TestGroup},
        },
        tree_kem::node::LeafIndex,
    };

    use super::{CiphertextProcessor, GroupStateProvider, MlsError};

    use alloc::vec;
    use assert_matches::assert_matches;

    /// A test group plus an application message signed by member 0 of that group.
    struct TestData {
        group: TestGroup,
        content: AuthenticatedContent,
    }

    /// Builds a processor over `group` using the test provider for `cipher_suite`.
    fn test_processor(
        group: &mut TestGroup,
        cipher_suite: CipherSuite,
    ) -> CiphertextProcessor<'_, impl GroupStateProvider, impl CipherSuiteProvider> {
        CiphertextProcessor::new(&mut group.group, test_cipher_suite_provider(cipher_suite))
    }

    /// Creates a test group and a signed application message from member 0.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_data(cipher_suite: CipherSuite) -> TestData {
        let provider = test_cipher_suite_provider(cipher_suite);

        let group = test_group(TEST_PROTOCOL_VERSION, cipher_suite).await;

        let content = AuthenticatedContent::new_signed(
            &provider,
            group.group.context(),
            Sender::Member(0),
            Content::Application(ApplicationData::from(b"test".to_vec())),
            &group.group.signer,
            WireFormat::PrivateMessage,
            vec![],
        )
        .await
        .unwrap();

        TestData { group, content }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_encrypt_decrypt() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let mut test_data = test_data(cipher_suite).await;
            let mut receiver_group = test_data.group.clone();

            let mut ciphertext_processor = test_processor(&mut test_data.group, cipher_suite);

            let ciphertext = ciphertext_processor
                .seal(test_data.content.clone(), PaddingMode::StepFunction)
                .await
                .unwrap();

            receiver_group.group.private_tree.self_index = LeafIndex::new(1);

            let mut receiver_processor = test_processor(&mut receiver_group, cipher_suite);

            let decrypted = receiver_processor.open(&ciphertext).await.unwrap();

            assert_eq!(decrypted, test_data.content);
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_padding_use() {
        let mut test_data = test_data(TEST_CIPHER_SUITE).await;
        let mut ciphertext_processor = test_processor(&mut test_data.group, TEST_CIPHER_SUITE);

        let ciphertext_step = ciphertext_processor
            .seal(test_data.content.clone(), PaddingMode::StepFunction)
            .await
            .unwrap();

        let ciphertext_no_pad = ciphertext_processor
            .seal(test_data.content.clone(), PaddingMode::None)
            .await
            .unwrap();

        assert!(ciphertext_step.ciphertext.len() > ciphertext_no_pad.ciphertext.len());
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_invalid_sender() {
        let mut test_data = test_data(TEST_CIPHER_SUITE).await;
        test_data.content.content.sender = Sender::Member(3);

        let mut ciphertext_processor = test_processor(&mut test_data.group, TEST_CIPHER_SUITE);

        let res = ciphertext_processor
            .seal(test_data.content, PaddingMode::None)
            .await;

        assert_matches!(res, Err(MlsError::InvalidSender))
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_cant_process_from_self() {
        let mut test_data = test_data(TEST_CIPHER_SUITE).await;

        let mut ciphertext_processor = test_processor(&mut test_data.group, TEST_CIPHER_SUITE);

        let ciphertext = ciphertext_processor
            .seal(test_data.content, PaddingMode::None)
            .await
            .unwrap();

        let res = ciphertext_processor.open(&ciphertext).await;

        assert_matches!(res, Err(MlsError::CantProcessMessageFromSelf))
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_decryption_error() {
        let mut test_data = test_data(TEST_CIPHER_SUITE).await;
        let mut receiver_group = test_data.group.clone();
        let mut ciphertext_processor = test_processor(&mut test_data.group, TEST_CIPHER_SUITE);

        let mut ciphertext = ciphertext_processor
            .seal(test_data.content.clone(), PaddingMode::StepFunction)
            .await
            .unwrap();

        ciphertext.ciphertext = random_bytes(ciphertext.ciphertext.len());

        // Open as a different member; opening with the sender's own processor would fail on
        // the "message from self" check and never reach decryption.
        receiver_group.group.private_tree.self_index = LeafIndex::new(1);
        let mut receiver_processor = test_processor(&mut receiver_group, TEST_CIPHER_SUITE);

        let res = receiver_processor.open(&ciphertext).await;

        assert!(res.is_err());
        assert!(!matches!(res, Err(MlsError::CantProcessMessageFromSelf)));
    }
}
411