// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use alloc::collections::VecDeque;

#[cfg(target_has_atomic = "ptr")]
use alloc::sync::Arc;

#[cfg(mls_build_async)]
use alloc::boxed::Box;
use alloc::vec::Vec;
use core::{
    convert::Infallible,
    fmt::{self, Debug},
};
use mls_rs_core::group::{EpochRecord, GroupState, GroupStateStorage};
#[cfg(not(target_has_atomic = "ptr"))]
use portable_atomic_util::Arc;

use crate::client::MlsError;

#[cfg(feature = "std")]
use std::collections::{hash_map::Entry, HashMap};

#[cfg(not(feature = "std"))]
use alloc::collections::{btree_map::Entry, BTreeMap};

#[cfg(feature = "std")]
use std::sync::Mutex;

#[cfg(not(feature = "std"))]
use spin::Mutex;

pub(crate) const DEFAULT_EPOCH_RETENTION_LIMIT: usize = 3;

#[derive(Clone)]
pub(crate) struct InMemoryGroupData {
    pub(crate) state_data: Vec<u8>,
    pub(crate) epoch_data: VecDeque<EpochRecord>,
}

impl Debug for InMemoryGroupData {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("InMemoryGroupData")
            .field(
                "state_data",
                &mls_rs_core::debug::pretty_bytes(&self.state_data),
            )
            .field("epoch_data", &self.epoch_data)
            .finish()
    }
}

impl InMemoryGroupData {
    pub fn new(state_data: Vec<u8>) -> InMemoryGroupData {
        InMemoryGroupData {
            state_data,
            epoch_data: Default::default(),
        }
    }

    fn get_epoch_data_index(&self, epoch_id: u64) -> Option<u64> {
        self.epoch_data
            .front()
            .and_then(|e| epoch_id.checked_sub(e.id))
    }

    pub fn get_epoch(&self, epoch_id: u64) -> Option<&EpochRecord> {
        self.get_epoch_data_index(epoch_id)
            .and_then(|i| self.epoch_data.get(i as usize))
    }

    pub fn get_mut_epoch(&mut self, epoch_id: u64) -> Option<&mut EpochRecord> {
        self.get_epoch_data_index(epoch_id)
            .and_then(|i| self.epoch_data.get_mut(i as usize))
    }

    pub fn insert_epoch(&mut self, epoch: EpochRecord) {
        self.epoch_data.push_back(epoch)
    }

    // This function does not fail if an update can't be made. If the epoch
    // is not in the store, then it can no longer be accessed by future
    // get_epoch calls and is no longer relevant.
    pub fn update_epoch(&mut self, epoch: EpochRecord) {
        if let Some(existing_epoch) = self.get_mut_epoch(epoch.id) {
            *existing_epoch = epoch
        }
    }

    pub fn trim_epochs(&mut self, max_epoch_retention: usize) {
        while self.epoch_data.len() > max_epoch_retention {
            self.epoch_data.pop_front();
        }
    }
}
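
// Worked example of the retention window above (comment added for clarity, not
// part of the original source): with `max_epoch_retention` = 3, inserting epochs
// 0..=4 and then calling `trim_epochs(3)` leaves ids [2, 3, 4] in `epoch_data`.
// `get_epoch(1)` then returns `None`, because `get_epoch_data_index` computes the
// offset from the oldest retained id and `1_u64.checked_sub(2)` underflows.
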
#[derive(Clone)]
/// In-memory group state storage backed by a HashMap.
///
/// All clones of an instance of this type share the same underlying HashMap.
pub struct InMemoryGroupStateStorage {
    #[cfg(feature = "std")]
    pub(crate) inner: Arc<Mutex<HashMap<Vec<u8>, InMemoryGroupData>>>,
    #[cfg(not(feature = "std"))]
    pub(crate) inner: Arc<Mutex<BTreeMap<Vec<u8>, InMemoryGroupData>>>,
    pub(crate) max_epoch_retention: usize,
}

impl Debug for InMemoryGroupStateStorage {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("InMemoryGroupStateStorage")
            .field(
                "inner",
                &mls_rs_core::debug::pretty_with(|f| {
                    f.debug_map()
                        .entries(
                            self.lock()
                                .iter()
                                .map(|(k, v)| (mls_rs_core::debug::pretty_bytes(k), v)),
                        )
                        .finish()
                }),
            )
            .field("max_epoch_retention", &self.max_epoch_retention)
            .finish()
    }
}

impl InMemoryGroupStateStorage {
    /// Create an empty group state storage.
    pub fn new() -> Self {
        Self {
            inner: Default::default(),
            max_epoch_retention: DEFAULT_EPOCH_RETENTION_LIMIT,
        }
    }

    /// Set the maximum number of epoch records retained per group.
    ///
    /// Returns `MlsError::NonZeroRetentionRequired` if `max_epoch_retention` is 0.
    pub fn with_max_epoch_retention(self, max_epoch_retention: usize) -> Result<Self, MlsError> {
        (max_epoch_retention > 0)
            .then_some(())
            .ok_or(MlsError::NonZeroRetentionRequired)?;

        Ok(Self {
            inner: self.inner,
            max_epoch_retention,
        })
    }

    /// Get the set of unique group ids that have data stored.
    pub fn stored_groups(&self) -> Vec<Vec<u8>> {
        self.lock().keys().cloned().collect()
    }

    /// Delete all data corresponding to `group_id`.
    pub fn delete_group(&self, group_id: &[u8]) {
        self.lock().remove(group_id);
    }

    #[cfg(feature = "std")]
    fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<Vec<u8>, InMemoryGroupData>> {
        self.inner.lock().unwrap()
    }

    #[cfg(not(feature = "std"))]
    fn lock(&self) -> spin::mutex::MutexGuard<'_, BTreeMap<Vec<u8>, InMemoryGroupData>> {
        self.inner.lock()
    }
}

impl Default for InMemoryGroupStateStorage {
    fn default() -> Self {
        Self::new()
    }
}
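
// Note on the `GroupStateStorage` implementation below (comment added for
// clarity, not part of the original source): a `write` replaces the serialized
// snapshot, appends `epoch_inserts` to the back of the epoch window, applies
// `epoch_updates` in place (updates for ids that were already trimmed are
// silently dropped), and finally trims the window down to `max_epoch_retention`.
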
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
impl GroupStateStorage for InMemoryGroupStateStorage {
    type Error = Infallible;

    async fn max_epoch_id(&self, group_id: &[u8]) -> Result<Option<u64>, Self::Error> {
        Ok(self
            .lock()
            .get(group_id)
            .and_then(|group_data| group_data.epoch_data.back().map(|e| e.id)))
    }

    async fn state(&self, group_id: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
        Ok(self
            .lock()
            .get(group_id)
            .map(|data| data.state_data.clone()))
    }

    async fn epoch(&self, group_id: &[u8], epoch_id: u64) -> Result<Option<Vec<u8>>, Self::Error> {
        Ok(self
            .lock()
            .get(group_id)
            .and_then(|data| data.get_epoch(epoch_id).map(|ep| ep.data.clone())))
    }

    async fn write(
        &mut self,
        state: GroupState,
        epoch_inserts: Vec<EpochRecord>,
        epoch_updates: Vec<EpochRecord>,
    ) -> Result<(), Self::Error> {
        let mut group_map = self.lock();

        let group_data = match group_map.entry(state.id) {
            Entry::Occupied(entry) => {
                let data = entry.into_mut();
                data.state_data = state.data;
                data
            }
            Entry::Vacant(entry) => entry.insert(InMemoryGroupData::new(state.data)),
        };

        epoch_inserts
            .into_iter()
            .for_each(|e| group_data.insert_epoch(e));

        epoch_updates
            .into_iter()
            .for_each(|e| group_data.update_epoch(e));

        group_data.trim_epochs(self.max_epoch_retention);

        Ok(())
    }
}

#[cfg(all(test, feature = "prior_epoch"))]
mod tests {
    use alloc::{format, vec, vec::Vec};
    use assert_matches::assert_matches;

    use super::{InMemoryGroupData, InMemoryGroupStateStorage};
    use crate::{client::MlsError, group::test_utils::TEST_GROUP};

    use mls_rs_core::group::{EpochRecord, GroupState, GroupStateStorage};

    impl InMemoryGroupStateStorage {
        fn test_data(&self) -> InMemoryGroupData {
            self.lock().get(TEST_GROUP).unwrap().clone()
        }
    }

    fn test_storage(retention_limit: usize) -> Result<InMemoryGroupStateStorage, MlsError> {
        InMemoryGroupStateStorage::new().with_max_epoch_retention(retention_limit)
    }

    fn test_epoch(epoch_id: u64) -> EpochRecord {
        EpochRecord::new(epoch_id, format!("epoch {epoch_id}").as_bytes().to_vec())
    }

    fn test_snapshot(epoch_id: u64) -> GroupState {
        GroupState {
            id: TEST_GROUP.into(),
            data: format!("snapshot {epoch_id}").as_bytes().to_vec(),
        }
    }

    #[test]
    fn test_zero_max_retention() {
        assert_matches!(test_storage(0), Err(MlsError::NonZeroRetentionRequired))
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn existing_storage_can_have_larger_epoch_count() {
        let mut storage = test_storage(2).unwrap();

        let epoch_inserts = vec![test_epoch(0), test_epoch(1)];

        storage
            .write(test_snapshot(0), epoch_inserts, Vec::new())
            .await
            .unwrap();

        assert_eq!(storage.test_data().epoch_data.len(), 2);

        storage.max_epoch_retention = 4;

        let epoch_inserts = vec![test_epoch(3), test_epoch(4)];

        storage
            .write(test_snapshot(1), epoch_inserts, Vec::new())
            .await
            .unwrap();

        assert_eq!(storage.test_data().epoch_data.len(), 4);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn existing_storage_can_have_smaller_epoch_count() {
        let mut storage = test_storage(4).unwrap();

        let epoch_inserts = vec![test_epoch(0), test_epoch(1), test_epoch(3), test_epoch(4)];

        storage
            .write(test_snapshot(1), epoch_inserts, Vec::new())
            .await
            .unwrap();

        assert_eq!(storage.test_data().epoch_data.len(), 4);

        storage.max_epoch_retention = 2;

        let epoch_inserts = vec![test_epoch(5)];

        storage
            .write(test_snapshot(1), epoch_inserts, Vec::new())
            .await
            .unwrap();

        assert_eq!(storage.test_data().epoch_data.len(), 2);
    }
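
    // Added sketch (not part of the original test suite): exercises the
    // `stored_groups` and `delete_group` helpers against the same fixtures used
    // above. Assumes `TEST_GROUP` is the byte-slice group id constant imported
    // from `crate::group::test_utils`.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn delete_group_removes_stored_data() {
        let mut storage = test_storage(2).unwrap();

        storage
            .write(test_snapshot(0), vec![test_epoch(0)], Vec::new())
            .await
            .unwrap();

        assert_eq!(storage.stored_groups(), vec![TEST_GROUP.to_vec()]);

        storage.delete_group(TEST_GROUP);

        assert!(storage.stored_groups().is_empty());
    }
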
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn epoch_insert_over_limit() {
        test_epoch_insert_over_limit(false).await
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn epoch_insert_over_limit_with_update() {
        test_epoch_insert_over_limit(true).await
    }

    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_epoch_insert_over_limit(with_update: bool) {
        let mut storage = test_storage(1).unwrap();

        let mut epoch_inserts = vec![test_epoch(0), test_epoch(1)];
        let updates = with_update
            .then_some(vec![test_epoch(0)])
            .unwrap_or_default();
        let snapshot = test_snapshot(1);

        storage
            .write(snapshot.clone(), epoch_inserts.clone(), updates)
            .await
            .unwrap();

        let stored = storage.test_data();

        assert_eq!(stored.state_data, snapshot.data);
        assert_eq!(stored.epoch_data.len(), 1);

        let expected = epoch_inserts.pop().unwrap();
        assert_eq!(stored.epoch_data[0], expected);
    }
}