1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use crate::client::MlsError;
6 use crate::{group::PriorEpoch, key_package::KeyPackageRef};
7 
8 use alloc::collections::VecDeque;
9 use alloc::vec::Vec;
10 use core::fmt::{self, Debug};
11 use mls_rs_codec::{MlsDecode, MlsEncode};
12 use mls_rs_core::group::{EpochRecord, GroupState};
13 use mls_rs_core::{error::IntoAnyError, group::GroupStateStorage, key_package::KeyPackageStorage};
14 
15 use super::snapshot::Snapshot;
16 
17 #[cfg(feature = "psk")]
18 use crate::group::ResumptionPsk;
19 
20 #[cfg(feature = "psk")]
21 use mls_rs_core::psk::PreSharedKey;
22 
23 /// A set of changes to apply to a GroupStateStorage implementation. These changes MUST
24 /// be made in a single transaction to avoid creating invalid states.
25 #[derive(Default, Clone, Debug)]
26 struct EpochStorageCommit {
27     pub(crate) inserts: VecDeque<PriorEpoch>,
28     pub(crate) updates: Vec<PriorEpoch>,
29 }
30 
31 #[derive(Clone)]
32 pub(crate) struct GroupStateRepository<S, K>
33 where
34     S: GroupStateStorage,
35     K: KeyPackageStorage,
36 {
37     pending_commit: EpochStorageCommit,
38     pending_key_package_removal: Option<KeyPackageRef>,
39     group_id: Vec<u8>,
40     storage: S,
41     key_package_repo: K,
42 }
43 
44 impl<S, K> Debug for GroupStateRepository<S, K>
45 where
46     S: GroupStateStorage + Debug,
47     K: KeyPackageStorage + Debug,
48 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result49     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50         f.debug_struct("GroupStateRepository")
51             .field("pending_commit", &self.pending_commit)
52             .field(
53                 "pending_key_package_removal",
54                 &self.pending_key_package_removal,
55             )
56             .field(
57                 "group_id",
58                 &mls_rs_core::debug::pretty_group_id(&self.group_id),
59             )
60             .field("storage", &self.storage)
61             .field("key_package_repo", &self.key_package_repo)
62             .finish()
63     }
64 }
65 
66 impl<S, K> GroupStateRepository<S, K>
67 where
68     S: GroupStateStorage,
69     K: KeyPackageStorage,
70 {
new( group_id: Vec<u8>, storage: S, key_package_repo: K, key_package_to_remove: Option<KeyPackageRef>, ) -> Result<GroupStateRepository<S, K>, MlsError>71     pub fn new(
72         group_id: Vec<u8>,
73         storage: S,
74         key_package_repo: K,
75         // Set to `None` if restoring from snapshot; set to `Some` when joining a group.
76         key_package_to_remove: Option<KeyPackageRef>,
77     ) -> Result<GroupStateRepository<S, K>, MlsError> {
78         Ok(GroupStateRepository {
79             group_id,
80             storage,
81             pending_key_package_removal: key_package_to_remove,
82             pending_commit: Default::default(),
83             key_package_repo,
84         })
85     }
86 
87     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
find_max_id(&self) -> Result<Option<u64>, MlsError>88     async fn find_max_id(&self) -> Result<Option<u64>, MlsError> {
89         if let Some(max) = self.pending_commit.inserts.back().map(|e| e.epoch_id()) {
90             Ok(Some(max))
91         } else {
92             self.storage
93                 .max_epoch_id(&self.group_id)
94                 .await
95                 .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))
96         }
97     }
98 
99     #[cfg(feature = "psk")]
100     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
resumption_secret( &self, psk_id: &ResumptionPsk, ) -> Result<Option<PreSharedKey>, MlsError>101     pub async fn resumption_secret(
102         &self,
103         psk_id: &ResumptionPsk,
104     ) -> Result<Option<PreSharedKey>, MlsError> {
105         // Search the local inserts cache
106         if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) {
107             if psk_id.psk_epoch >= min {
108                 return Ok(self
109                     .pending_commit
110                     .inserts
111                     .get((psk_id.psk_epoch - min) as usize)
112                     .map(|e| e.secrets.resumption_secret.clone()));
113             }
114         }
115 
116         // Search the local updates cache
117         let maybe_pending = self.find_pending(psk_id.psk_epoch);
118 
119         if let Some(pending) = maybe_pending {
120             return Ok(Some(
121                 self.pending_commit.updates[pending]
122                     .secrets
123                     .resumption_secret
124                     .clone(),
125             ));
126         }
127 
128         // Search the stored cache
129         self.storage
130             .epoch(&psk_id.psk_group_id.0, psk_id.psk_epoch)
131             .await
132             .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
133             .map(|e| Ok(PriorEpoch::mls_decode(&mut &*e)?.secrets.resumption_secret))
134             .transpose()
135     }
136 
137     #[cfg(feature = "private_message")]
138     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_epoch_mut( &mut self, epoch_id: u64, ) -> Result<Option<&mut PriorEpoch>, MlsError>139     pub async fn get_epoch_mut(
140         &mut self,
141         epoch_id: u64,
142     ) -> Result<Option<&mut PriorEpoch>, MlsError> {
143         // Search the local inserts cache
144         if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) {
145             if epoch_id >= min {
146                 return Ok(self
147                     .pending_commit
148                     .inserts
149                     .get_mut((epoch_id - min) as usize));
150             }
151         }
152 
153         // Look in the cached updates map, and if not found look in disk storage
154         // and insert into the updates map for future caching
155         match self.find_pending(epoch_id) {
156             Some(i) => self.pending_commit.updates.get_mut(i).map(Ok),
157             None => self
158                 .storage
159                 .epoch(&self.group_id, epoch_id)
160                 .await
161                 .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
162                 .and_then(|epoch| {
163                     PriorEpoch::mls_decode(&mut &*epoch)
164                         .map(|epoch| {
165                             self.pending_commit.updates.push(epoch);
166                             self.pending_commit.updates.last_mut()
167                         })
168                         .transpose()
169                 }),
170         }
171         .transpose()
172         .map_err(Into::into)
173     }
174 
175     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
insert(&mut self, epoch: PriorEpoch) -> Result<(), MlsError>176     pub async fn insert(&mut self, epoch: PriorEpoch) -> Result<(), MlsError> {
177         if epoch.group_id() != self.group_id {
178             return Err(MlsError::GroupIdMismatch);
179         }
180 
181         let epoch_id = epoch.epoch_id();
182 
183         if let Some(expected_id) = self.find_max_id().await?.map(|id| id + 1) {
184             if epoch_id != expected_id {
185                 return Err(MlsError::InvalidEpoch);
186             }
187         }
188 
189         self.pending_commit.inserts.push_back(epoch);
190 
191         Ok(())
192     }
193 
194     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError>195     pub async fn write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError> {
196         let inserts = self
197             .pending_commit
198             .inserts
199             .iter()
200             .map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?)))
201             .collect::<Result<_, MlsError>>()?;
202 
203         let updates = self
204             .pending_commit
205             .updates
206             .iter()
207             .map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?)))
208             .collect::<Result<_, MlsError>>()?;
209 
210         let group_state = GroupState {
211             data: group_snapshot.mls_encode_to_vec()?,
212             id: group_snapshot.state.context.group_id,
213         };
214 
215         self.storage
216             .write(group_state, inserts, updates)
217             .await
218             .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?;
219 
220         if let Some(ref key_package_ref) = self.pending_key_package_removal {
221             self.key_package_repo
222                 .delete(key_package_ref)
223                 .await
224                 .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))?;
225         }
226 
227         self.pending_commit.inserts.clear();
228         self.pending_commit.updates.clear();
229 
230         Ok(())
231     }
232 
233     #[cfg(any(feature = "psk", feature = "private_message"))]
find_pending(&self, epoch_id: u64) -> Option<usize>234     fn find_pending(&self, epoch_id: u64) -> Option<usize> {
235         self.pending_commit
236             .updates
237             .iter()
238             .position(|ep| ep.context.epoch == epoch_id)
239     }
240 }
241 
#[cfg(test)]
mod tests {
    use alloc::vec;
    use mls_rs_codec::MlsEncode;

    use crate::{
        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        group::{
            epoch::{test_utils::get_test_epoch_with_id, SenderDataSecret},
            test_utils::{random_bytes, test_member, TEST_GROUP},
            PskGroupId, ResumptionPSKUsage,
        },
        storage_provider::in_memory::{InMemoryGroupStateStorage, InMemoryKeyPackageStorage},
    };

    use super::*;

    // Locks the `inner` map of an `InMemoryGroupStateStorage` and yields the guard. This hides
    // the difference between the `std` build (lock result is unwrapped) and the `no_std`
    // build (lock returns the guard directly).
    macro_rules! lock_storage {
        ($storage:expr) => {{
            #[cfg(feature = "std")]
            let guard = $storage.inner.lock().unwrap();
            #[cfg(not(feature = "std"))]
            let guard = $storage.inner.lock();
            guard
        }};
    }

    /// Builds a repository for `TEST_GROUP` on fresh in-memory storage that keeps at most
    /// `retention_limit` epochs.
    fn test_group_state_repo(
        retention_limit: usize,
    ) -> GroupStateRepository<InMemoryGroupStateStorage, InMemoryKeyPackageStorage> {
        GroupStateRepository::new(
            TEST_GROUP.to_vec(),
            InMemoryGroupStateStorage::new()
                .with_max_epoch_retention(retention_limit)
                .unwrap(),
            InMemoryKeyPackageStorage::default(),
            None,
        )
        .unwrap()
    }

    /// Test epoch of `TEST_GROUP` with the given id.
    fn test_epoch(epoch_id: u64) -> PriorEpoch {
        get_test_epoch_with_id(TEST_GROUP.to_vec(), TEST_CIPHER_SUITE, epoch_id)
    }

    /// Test group snapshot at the given epoch.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_snapshot(epoch_id: u64) -> Snapshot {
        crate::group::snapshot::test_utils::get_test_snapshot(TEST_CIPHER_SUITE, epoch_id).await
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_epoch_inserts() {
        let mut test_repo = test_group_state_repo(1);
        let test_epoch = test_epoch(0);

        test_repo.insert(test_epoch.clone()).await.unwrap();

        // Check the in-memory state
        assert_eq!(
            test_repo.pending_commit.inserts.back().unwrap(),
            &test_epoch
        );

        assert!(test_repo.pending_commit.updates.is_empty());

        // Nothing is written until `write_to_storage` is called
        assert!(lock_storage!(test_repo.storage).is_empty());

        let psk_id = ResumptionPsk {
            psk_epoch: 0,
            psk_group_id: PskGroupId(test_repo.group_id.clone()),
            usage: ResumptionPSKUsage::Application,
        };

        // Make sure you can recall an epoch sitting as a pending insert
        let resumption = test_repo.resumption_secret(&psk_id).await.unwrap();
        let prior_epoch = test_repo.get_epoch_mut(0).await.unwrap().cloned();

        assert_eq!(
            prior_epoch.clone().unwrap().secrets.resumption_secret,
            resumption.unwrap()
        );

        assert_eq!(prior_epoch.unwrap(), test_epoch);

        // Write to the storage
        let snapshot = test_snapshot(test_epoch.epoch_id()).await;
        test_repo.write_to_storage(snapshot.clone()).await.unwrap();

        // Make sure the memory cache cleared
        assert!(test_repo.pending_commit.inserts.is_empty());
        assert!(test_repo.pending_commit.updates.is_empty());

        // Make sure the storage was written
        let storage = lock_storage!(test_repo.storage);

        assert_eq!(storage.len(), 1);

        let stored = storage.get(TEST_GROUP).unwrap();

        assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap());

        assert_eq!(stored.epoch_data.len(), 1);

        assert_eq!(
            stored.epoch_data.back().unwrap(),
            &EpochRecord::new(
                test_epoch.epoch_id(),
                test_epoch.mls_encode_to_vec().unwrap()
            )
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_updates() {
        let mut test_repo = test_group_state_repo(2);
        let test_epoch_0 = test_epoch(0);

        test_repo.insert(test_epoch_0.clone()).await.unwrap();

        test_repo
            .write_to_storage(test_snapshot(0).await)
            .await
            .unwrap();

        // Update the stored epoch
        let to_update = test_repo.get_epoch_mut(0).await.unwrap().unwrap();
        assert_eq!(to_update, &test_epoch_0);

        let new_sender_secret = random_bytes(32);
        to_update.secrets.sender_data_secret = SenderDataSecret::from(new_sender_secret);
        let to_update = to_update.clone();

        assert_eq!(test_repo.pending_commit.updates.len(), 1);
        assert!(test_repo.pending_commit.inserts.is_empty());

        assert_eq!(
            test_repo.pending_commit.updates.first().unwrap(),
            &to_update
        );

        // Make sure you can access an epoch pending update
        let psk_id = ResumptionPsk {
            psk_epoch: 0,
            psk_group_id: PskGroupId(test_repo.group_id.clone()),
            usage: ResumptionPSKUsage::Application,
        };

        let owned = test_repo.resumption_secret(&psk_id).await.unwrap();
        assert_eq!(owned.as_ref(), Some(&to_update.secrets.resumption_secret));

        // Write the update to storage
        let snapshot = test_snapshot(1).await;
        test_repo.write_to_storage(snapshot.clone()).await.unwrap();

        assert!(test_repo.pending_commit.updates.is_empty());
        assert!(test_repo.pending_commit.inserts.is_empty());

        // Make sure the storage was written
        let storage = lock_storage!(test_repo.storage);

        assert_eq!(storage.len(), 1);

        let stored = storage.get(TEST_GROUP).unwrap();

        assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap());

        assert_eq!(stored.epoch_data.len(), 1);

        assert_eq!(
            stored.epoch_data.back().unwrap(),
            &EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap())
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_insert_and_update() {
        let mut test_repo = test_group_state_repo(2);
        let test_epoch_0 = test_epoch(0);

        test_repo.insert(test_epoch_0).await.unwrap();

        test_repo
            .write_to_storage(test_snapshot(0).await)
            .await
            .unwrap();

        // Update the stored epoch
        let to_update = test_repo.get_epoch_mut(0).await.unwrap().unwrap();
        let new_sender_secret = random_bytes(32);
        to_update.secrets.sender_data_secret = SenderDataSecret::from(new_sender_secret);
        let to_update = to_update.clone();

        // Insert another epoch
        let test_epoch_1 = test_epoch(1);
        test_repo.insert(test_epoch_1.clone()).await.unwrap();

        test_repo
            .write_to_storage(test_snapshot(1).await)
            .await
            .unwrap();

        assert!(test_repo.pending_commit.inserts.is_empty());
        assert!(test_repo.pending_commit.updates.is_empty());

        // Make sure the storage was written
        let storage = lock_storage!(test_repo.storage);

        assert_eq!(storage.len(), 1);

        let stored = storage.get(TEST_GROUP).unwrap();

        assert_eq!(stored.epoch_data.len(), 2);

        assert_eq!(
            stored.epoch_data.front().unwrap(),
            &EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap())
        );

        assert_eq!(
            stored.epoch_data.back().unwrap(),
            &EpochRecord::new(
                test_epoch_1.epoch_id(),
                test_epoch_1.mls_encode_to_vec().unwrap()
            )
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_many_epochs_in_storage() {
        let epochs = (0..10).map(test_epoch).collect::<Vec<_>>();

        let mut test_repo = test_group_state_repo(10);

        for epoch in epochs.iter().cloned() {
            test_repo.insert(epoch).await.unwrap()
        }

        test_repo
            .write_to_storage(test_snapshot(9).await)
            .await
            .unwrap();

        for mut epoch in epochs {
            let res = test_repo.get_epoch_mut(epoch.epoch_id()).await.unwrap();

            assert_eq!(res, Some(&mut epoch));
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_stored_groups_list() {
        let mut test_repo = test_group_state_repo(2);
        let test_epoch_0 = test_epoch(0);

        test_repo.insert(test_epoch_0.clone()).await.unwrap();

        test_repo
            .write_to_storage(test_snapshot(0).await)
            .await
            .unwrap();

        assert_eq!(
            test_repo.storage.stored_groups(),
            vec![test_epoch_0.context.group_id]
        )
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn reducing_retention_limit_takes_effect_on_epoch_access() {
        let mut repo = test_group_state_repo(1);

        repo.insert(test_epoch(0)).await.unwrap();
        repo.insert(test_epoch(1)).await.unwrap();

        repo.write_to_storage(test_snapshot(0).await).await.unwrap();

        // Reopen the same storage through a fresh repository with empty pending caches
        let mut repo = GroupStateRepository {
            storage: repo.storage,
            ..test_group_state_repo(1)
        };

        let res = repo.get_epoch_mut(0).await.unwrap();

        assert!(res.is_none());
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn in_memory_storage_obeys_retention_limit_after_saving() {
        let mut repo = test_group_state_repo(1);

        repo.insert(test_epoch(0)).await.unwrap();
        repo.write_to_storage(test_snapshot(0).await).await.unwrap();
        repo.insert(test_epoch(1)).await.unwrap();
        repo.write_to_storage(test_snapshot(1).await).await.unwrap();

        let lock = lock_storage!(repo.storage);

        assert_eq!(lock.get(TEST_GROUP).unwrap().epoch_data.len(), 1);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn used_key_package_is_deleted() {
        let key_package_repo = InMemoryKeyPackageStorage::default();

        let key_package = test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"member")
            .await
            .0;

        let (id, data) = key_package.to_storage().unwrap();

        key_package_repo.insert(id, data);

        let mut repo = GroupStateRepository::new(
            TEST_GROUP.to_vec(),
            InMemoryGroupStateStorage::new(),
            key_package_repo,
            Some(key_package.reference.clone()),
        )
        .unwrap();

        // The key package is present before the first write...
        repo.key_package_repo.get(&key_package.reference).unwrap();

        repo.write_to_storage(test_snapshot(4).await).await.unwrap();

        // ...and removed after it.
        assert!(repo.key_package_repo.get(&key_package.reference).is_none());
    }
}
574