1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::collections::VecDeque;
6 
7 #[cfg(target_has_atomic = "ptr")]
8 use alloc::sync::Arc;
9 
10 #[cfg(mls_build_async)]
11 use alloc::boxed::Box;
12 use alloc::vec::Vec;
13 use core::{
14     convert::Infallible,
15     fmt::{self, Debug},
16 };
17 use mls_rs_core::group::{EpochRecord, GroupState, GroupStateStorage};
18 #[cfg(not(target_has_atomic = "ptr"))]
19 use portable_atomic_util::Arc;
20 
21 use crate::client::MlsError;
22 
23 #[cfg(feature = "std")]
24 use std::collections::{hash_map::Entry, HashMap};
25 
26 #[cfg(not(feature = "std"))]
27 use alloc::collections::{btree_map::Entry, BTreeMap};
28 
29 #[cfg(feature = "std")]
30 use std::sync::Mutex;
31 
32 #[cfg(not(feature = "std"))]
33 use spin::Mutex;
34 
/// Number of epoch records kept per group by
/// [`InMemoryGroupStateStorage::new`] unless overridden with
/// [`InMemoryGroupStateStorage::with_max_epoch_retention`].
pub(crate) const DEFAULT_EPOCH_RETENTION_LIMIT: usize = 3;
36 
/// Everything stored for a single group: its latest serialized state plus a
/// window of retained epoch records.
///
/// New epoch records are appended at the back of `epoch_data` and trimming
/// removes records from the front, so the oldest retained epoch is first.
#[derive(Clone)]
pub(crate) struct InMemoryGroupData {
    // Serialized group state, replaced on every write for this group.
    pub(crate) state_data: Vec<u8>,
    // Retained epoch records, oldest at the front.
    pub(crate) epoch_data: VecDeque<EpochRecord>,
}
42 
43 impl Debug for InMemoryGroupData {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result44     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45         f.debug_struct("InMemoryGroupData")
46             .field(
47                 "state_data",
48                 &mls_rs_core::debug::pretty_bytes(&self.state_data),
49             )
50             .field("epoch_data", &self.epoch_data)
51             .finish()
52     }
53 }
54 
55 impl InMemoryGroupData {
new(state_data: Vec<u8>) -> InMemoryGroupData56     pub fn new(state_data: Vec<u8>) -> InMemoryGroupData {
57         InMemoryGroupData {
58             state_data,
59             epoch_data: Default::default(),
60         }
61     }
62 
get_epoch_data_index(&self, epoch_id: u64) -> Option<u64>63     fn get_epoch_data_index(&self, epoch_id: u64) -> Option<u64> {
64         self.epoch_data
65             .front()
66             .and_then(|e| epoch_id.checked_sub(e.id))
67     }
68 
get_epoch(&self, epoch_id: u64) -> Option<&EpochRecord>69     pub fn get_epoch(&self, epoch_id: u64) -> Option<&EpochRecord> {
70         self.get_epoch_data_index(epoch_id)
71             .and_then(|i| self.epoch_data.get(i as usize))
72     }
73 
get_mut_epoch(&mut self, epoch_id: u64) -> Option<&mut EpochRecord>74     pub fn get_mut_epoch(&mut self, epoch_id: u64) -> Option<&mut EpochRecord> {
75         self.get_epoch_data_index(epoch_id)
76             .and_then(|i| self.epoch_data.get_mut(i as usize))
77     }
78 
insert_epoch(&mut self, epoch: EpochRecord)79     pub fn insert_epoch(&mut self, epoch: EpochRecord) {
80         self.epoch_data.push_back(epoch)
81     }
82 
83     // This function does not fail if an update can't be made. If the epoch
84     // is not in the store, then it can no longer be accessed by future
85     // get_epoch calls and is no longer relevant.
update_epoch(&mut self, epoch: EpochRecord)86     pub fn update_epoch(&mut self, epoch: EpochRecord) {
87         if let Some(existing_epoch) = self.get_mut_epoch(epoch.id) {
88             *existing_epoch = epoch
89         }
90     }
91 
trim_epochs(&mut self, max_epoch_retention: usize)92     pub fn trim_epochs(&mut self, max_epoch_retention: usize) {
93         while self.epoch_data.len() > max_epoch_retention {
94             self.epoch_data.pop_front();
95         }
96     }
97 }
98 
99 #[derive(Clone)]
100 /// In memory group state storage backed by a HashMap.
101 ///
102 /// All clones of an instance of this type share the same underlying HashMap.
103 pub struct InMemoryGroupStateStorage {
104     #[cfg(feature = "std")]
105     pub(crate) inner: Arc<Mutex<HashMap<Vec<u8>, InMemoryGroupData>>>,
106     #[cfg(not(feature = "std"))]
107     pub(crate) inner: Arc<Mutex<BTreeMap<Vec<u8>, InMemoryGroupData>>>,
108     pub(crate) max_epoch_retention: usize,
109 }
110 
111 impl Debug for InMemoryGroupStateStorage {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result112     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113         f.debug_struct("InMemoryGroupStateStorage")
114             .field(
115                 "inner",
116                 &mls_rs_core::debug::pretty_with(|f| {
117                     f.debug_map()
118                         .entries(
119                             self.lock()
120                                 .iter()
121                                 .map(|(k, v)| (mls_rs_core::debug::pretty_bytes(k), v)),
122                         )
123                         .finish()
124                 }),
125             )
126             .field("max_epoch_retention", &self.max_epoch_retention)
127             .finish()
128     }
129 }
130 
131 impl InMemoryGroupStateStorage {
132     /// Create an empty group state storage.
new() -> Self133     pub fn new() -> Self {
134         Self {
135             inner: Default::default(),
136             max_epoch_retention: DEFAULT_EPOCH_RETENTION_LIMIT,
137         }
138     }
139 
with_max_epoch_retention(self, max_epoch_retention: usize) -> Result<Self, MlsError>140     pub fn with_max_epoch_retention(self, max_epoch_retention: usize) -> Result<Self, MlsError> {
141         (max_epoch_retention > 0)
142             .then_some(())
143             .ok_or(MlsError::NonZeroRetentionRequired)?;
144 
145         Ok(Self {
146             inner: self.inner,
147             max_epoch_retention,
148         })
149     }
150 
151     /// Get the set of unique group ids that have data stored.
stored_groups(&self) -> Vec<Vec<u8>>152     pub fn stored_groups(&self) -> Vec<Vec<u8>> {
153         self.lock().keys().cloned().collect()
154     }
155 
156     /// Delete all data corresponding to `group_id`.
delete_group(&self, group_id: &[u8])157     pub fn delete_group(&self, group_id: &[u8]) {
158         self.lock().remove(group_id);
159     }
160 
161     #[cfg(feature = "std")]
lock(&self) -> std::sync::MutexGuard<'_, HashMap<Vec<u8>, InMemoryGroupData>>162     fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<Vec<u8>, InMemoryGroupData>> {
163         self.inner.lock().unwrap()
164     }
165 
166     #[cfg(not(feature = "std"))]
lock(&self) -> spin::mutex::MutexGuard<'_, BTreeMap<Vec<u8>, InMemoryGroupData>>167     fn lock(&self) -> spin::mutex::MutexGuard<'_, BTreeMap<Vec<u8>, InMemoryGroupData>> {
168         self.inner.lock()
169     }
170 }
171 
172 impl Default for InMemoryGroupStateStorage {
default() -> Self173     fn default() -> Self {
174         Self::new()
175     }
176 }
177 
178 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
179 #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
180 impl GroupStateStorage for InMemoryGroupStateStorage {
181     type Error = Infallible;
182 
max_epoch_id(&self, group_id: &[u8]) -> Result<Option<u64>, Self::Error>183     async fn max_epoch_id(&self, group_id: &[u8]) -> Result<Option<u64>, Self::Error> {
184         Ok(self
185             .lock()
186             .get(group_id)
187             .and_then(|group_data| group_data.epoch_data.back().map(|e| e.id)))
188     }
189 
state(&self, group_id: &[u8]) -> Result<Option<Vec<u8>>, Self::Error>190     async fn state(&self, group_id: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
191         Ok(self
192             .lock()
193             .get(group_id)
194             .map(|data| data.state_data.clone()))
195     }
196 
epoch(&self, group_id: &[u8], epoch_id: u64) -> Result<Option<Vec<u8>>, Self::Error>197     async fn epoch(&self, group_id: &[u8], epoch_id: u64) -> Result<Option<Vec<u8>>, Self::Error> {
198         Ok(self
199             .lock()
200             .get(group_id)
201             .and_then(|data| data.get_epoch(epoch_id).map(|ep| ep.data.clone())))
202     }
203 
write( &mut self, state: GroupState, epoch_inserts: Vec<EpochRecord>, epoch_updates: Vec<EpochRecord>, ) -> Result<(), Self::Error>204     async fn write(
205         &mut self,
206         state: GroupState,
207         epoch_inserts: Vec<EpochRecord>,
208         epoch_updates: Vec<EpochRecord>,
209     ) -> Result<(), Self::Error> {
210         let mut group_map = self.lock();
211 
212         let group_data = match group_map.entry(state.id) {
213             Entry::Occupied(entry) => {
214                 let data = entry.into_mut();
215                 data.state_data = state.data;
216                 data
217             }
218             Entry::Vacant(entry) => entry.insert(InMemoryGroupData::new(state.data)),
219         };
220 
221         epoch_inserts
222             .into_iter()
223             .for_each(|e| group_data.insert_epoch(e));
224 
225         epoch_updates
226             .into_iter()
227             .for_each(|e| group_data.update_epoch(e));
228 
229         group_data.trim_epochs(self.max_epoch_retention);
230 
231         Ok(())
232     }
233 }
234 
#[cfg(all(test, feature = "prior_epoch"))]
mod tests {
    use alloc::{format, vec, vec::Vec};
    use assert_matches::assert_matches;

    use super::{InMemoryGroupData, InMemoryGroupStateStorage};
    use crate::{client::MlsError, group::test_utils::TEST_GROUP};

    use mls_rs_core::group::{EpochRecord, GroupState, GroupStateStorage};

    impl InMemoryGroupStateStorage {
        /// Clone of everything stored for `TEST_GROUP`; panics if nothing is stored.
        fn test_data(&self) -> InMemoryGroupData {
            self.lock().get(TEST_GROUP).unwrap().clone()
        }
    }

    fn test_storage(retention_limit: usize) -> Result<InMemoryGroupStateStorage, MlsError> {
        InMemoryGroupStateStorage::new().with_max_epoch_retention(retention_limit)
    }

    fn test_epoch(epoch_id: u64) -> EpochRecord {
        EpochRecord::new(epoch_id, format!("epoch {epoch_id}").as_bytes().to_vec())
    }

    fn test_snapshot(epoch_id: u64) -> GroupState {
        GroupState {
            id: TEST_GROUP.into(),
            data: format!("snapshot {epoch_id}").as_bytes().to_vec(),
        }
    }

    #[test]
    fn test_zero_max_retention() {
        assert_matches!(test_storage(0), Err(MlsError::NonZeroRetentionRequired))
    }

    // Epoch ids below are kept consecutive, matching how the storage is used
    // (and what its index-based epoch lookup expects).

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn existing_storage_can_have_larger_epoch_count() {
        let mut storage = test_storage(2).unwrap();

        let epoch_inserts = vec![test_epoch(0), test_epoch(1)];

        storage
            .write(test_snapshot(1), epoch_inserts, Vec::new())
            .await
            .unwrap();

        assert_eq!(storage.test_data().epoch_data.len(), 2);

        storage.max_epoch_retention = 4;

        let epoch_inserts = vec![test_epoch(2), test_epoch(3)];

        storage
            .write(test_snapshot(3), epoch_inserts, Vec::new())
            .await
            .unwrap();

        assert_eq!(storage.test_data().epoch_data.len(), 4);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn existing_storage_can_have_smaller_epoch_count() {
        let mut storage = test_storage(4).unwrap();

        let epoch_inserts = vec![test_epoch(0), test_epoch(1), test_epoch(2), test_epoch(3)];

        storage
            .write(test_snapshot(3), epoch_inserts, Vec::new())
            .await
            .unwrap();

        assert_eq!(storage.test_data().epoch_data.len(), 4);

        storage.max_epoch_retention = 2;

        let epoch_inserts = vec![test_epoch(4)];

        storage
            .write(test_snapshot(4), epoch_inserts, Vec::new())
            .await
            .unwrap();

        assert_eq!(storage.test_data().epoch_data.len(), 2);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn epoch_lookup_respects_retention() {
        let mut storage = test_storage(2).unwrap();

        let epoch_inserts = vec![test_epoch(0), test_epoch(1), test_epoch(2)];

        storage
            .write(test_snapshot(2), epoch_inserts, Vec::new())
            .await
            .unwrap();

        assert_eq!(storage.max_epoch_id(TEST_GROUP).await.unwrap(), Some(2));
        assert_eq!(
            storage.state(TEST_GROUP).await.unwrap(),
            Some(test_snapshot(2).data)
        );

        // Epoch 0 was trimmed; epochs 1 and 2 are retained; epoch 3 never existed.
        assert_eq!(storage.epoch(TEST_GROUP, 0).await.unwrap(), None);
        assert_eq!(
            storage.epoch(TEST_GROUP, 1).await.unwrap(),
            Some(test_epoch(1).data)
        );
        assert_eq!(
            storage.epoch(TEST_GROUP, 2).await.unwrap(),
            Some(test_epoch(2).data)
        );
        assert_eq!(storage.epoch(TEST_GROUP, 3).await.unwrap(), None);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn epoch_insert_over_limit() {
        test_epoch_insert_over_limit(false).await
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn epoch_insert_over_limit_with_update() {
        test_epoch_insert_over_limit(true).await
    }

    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_epoch_insert_over_limit(with_update: bool) {
        let mut storage = test_storage(1).unwrap();

        let mut epoch_inserts = vec![test_epoch(0), test_epoch(1)];
        let updates = with_update
            .then_some(vec![test_epoch(0)])
            .unwrap_or_default();
        let snapshot = test_snapshot(1);

        storage
            .write(snapshot.clone(), epoch_inserts.clone(), updates)
            .await
            .unwrap();

        let stored = storage.test_data();

        assert_eq!(stored.state_data, snapshot.data);
        assert_eq!(stored.epoch_data.len(), 1);

        let expected = epoch_inserts.pop().unwrap();
        assert_eq!(stored.epoch_data[0], expected);
    }
}
355