1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use crate::client::MlsError; 6 use crate::{group::PriorEpoch, key_package::KeyPackageRef}; 7 8 use alloc::collections::VecDeque; 9 use alloc::vec::Vec; 10 use core::fmt::{self, Debug}; 11 use mls_rs_codec::{MlsDecode, MlsEncode}; 12 use mls_rs_core::group::{EpochRecord, GroupState}; 13 use mls_rs_core::{error::IntoAnyError, group::GroupStateStorage, key_package::KeyPackageStorage}; 14 15 use super::snapshot::Snapshot; 16 17 #[cfg(feature = "psk")] 18 use crate::group::ResumptionPsk; 19 20 #[cfg(feature = "psk")] 21 use mls_rs_core::psk::PreSharedKey; 22 23 /// A set of changes to apply to a GroupStateStorage implementation. These changes MUST 24 /// be made in a single transaction to avoid creating invalid states. 25 #[derive(Default, Clone, Debug)] 26 struct EpochStorageCommit { 27 pub(crate) inserts: VecDeque<PriorEpoch>, 28 pub(crate) updates: Vec<PriorEpoch>, 29 } 30 31 #[derive(Clone)] 32 pub(crate) struct GroupStateRepository<S, K> 33 where 34 S: GroupStateStorage, 35 K: KeyPackageStorage, 36 { 37 pending_commit: EpochStorageCommit, 38 pending_key_package_removal: Option<KeyPackageRef>, 39 group_id: Vec<u8>, 40 storage: S, 41 key_package_repo: K, 42 } 43 44 impl<S, K> Debug for GroupStateRepository<S, K> 45 where 46 S: GroupStateStorage + Debug, 47 K: KeyPackageStorage + Debug, 48 { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 50 f.debug_struct("GroupStateRepository") 51 .field("pending_commit", &self.pending_commit) 52 .field( 53 "pending_key_package_removal", 54 &self.pending_key_package_removal, 55 ) 56 .field( 57 "group_id", 58 &mls_rs_core::debug::pretty_group_id(&self.group_id), 59 ) 60 .field("storage", &self.storage) 61 .field("key_package_repo", &self.key_package_repo) 62 .finish() 63 } 64 } 65 66 impl<S, K> GroupStateRepository<S, K> 67 where 68 S: GroupStateStorage, 69 K: KeyPackageStorage, 70 { new( group_id: Vec<u8>, storage: S, key_package_repo: K, key_package_to_remove: Option<KeyPackageRef>, ) -> Result<GroupStateRepository<S, K>, MlsError>71 pub fn new( 72 group_id: Vec<u8>, 73 storage: S, 74 key_package_repo: K, 75 // Set to `None` if restoring from snapshot; set to `Some` when joining a group. 76 key_package_to_remove: Option<KeyPackageRef>, 77 ) -> Result<GroupStateRepository<S, K>, MlsError> { 78 Ok(GroupStateRepository { 79 group_id, 80 storage, 81 pending_key_package_removal: key_package_to_remove, 82 pending_commit: Default::default(), 83 key_package_repo, 84 }) 85 } 86 87 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] find_max_id(&self) -> Result<Option<u64>, MlsError>88 async fn find_max_id(&self) -> Result<Option<u64>, MlsError> { 89 if let Some(max) = self.pending_commit.inserts.back().map(|e| e.epoch_id()) { 90 Ok(Some(max)) 91 } else { 92 self.storage 93 .max_epoch_id(&self.group_id) 94 .await 95 .map_err(|e| MlsError::GroupStorageError(e.into_any_error())) 96 } 97 } 98 99 #[cfg(feature = "psk")] 100 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] resumption_secret( &self, psk_id: &ResumptionPsk, ) -> Result<Option<PreSharedKey>, MlsError>101 pub async fn resumption_secret( 102 &self, 103 psk_id: &ResumptionPsk, 104 ) -> Result<Option<PreSharedKey>, MlsError> { 105 // Search the local inserts cache 106 if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) { 107 if psk_id.psk_epoch >= min { 108 return Ok(self 109 .pending_commit 110 .inserts 111 .get((psk_id.psk_epoch - min) as usize) 112 .map(|e| e.secrets.resumption_secret.clone())); 113 } 114 } 115 116 // Search the local updates cache 117 let maybe_pending = self.find_pending(psk_id.psk_epoch); 118 119 if let Some(pending) = maybe_pending { 120 return Ok(Some( 121 self.pending_commit.updates[pending] 122 .secrets 123 .resumption_secret 124 .clone(), 125 )); 126 } 127 128 // Search the stored cache 129 self.storage 130 .epoch(&psk_id.psk_group_id.0, psk_id.psk_epoch) 131 .await 132 .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))? 133 .map(|e| Ok(PriorEpoch::mls_decode(&mut &*e)?.secrets.resumption_secret)) 134 .transpose() 135 } 136 137 #[cfg(feature = "private_message")] 138 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] get_epoch_mut( &mut self, epoch_id: u64, ) -> Result<Option<&mut PriorEpoch>, MlsError>139 pub async fn get_epoch_mut( 140 &mut self, 141 epoch_id: u64, 142 ) -> Result<Option<&mut PriorEpoch>, MlsError> { 143 // Search the local inserts cache 144 if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) { 145 if epoch_id >= min { 146 return Ok(self 147 .pending_commit 148 .inserts 149 .get_mut((epoch_id - min) as usize)); 150 } 151 } 152 153 // Look in the cached updates map, and if not found look in disk storage 154 // and insert into the updates map for future caching 155 match self.find_pending(epoch_id) { 156 Some(i) => self.pending_commit.updates.get_mut(i).map(Ok), 157 None => self 158 .storage 159 .epoch(&self.group_id, epoch_id) 160 .await 161 .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))? 162 .and_then(|epoch| { 163 PriorEpoch::mls_decode(&mut &*epoch) 164 .map(|epoch| { 165 self.pending_commit.updates.push(epoch); 166 self.pending_commit.updates.last_mut() 167 }) 168 .transpose() 169 }), 170 } 171 .transpose() 172 .map_err(Into::into) 173 } 174 175 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] insert(&mut self, epoch: PriorEpoch) -> Result<(), MlsError>176 pub async fn insert(&mut self, epoch: PriorEpoch) -> Result<(), MlsError> { 177 if epoch.group_id() != self.group_id { 178 return Err(MlsError::GroupIdMismatch); 179 } 180 181 let epoch_id = epoch.epoch_id(); 182 183 if let Some(expected_id) = self.find_max_id().await?.map(|id| id + 1) { 184 if epoch_id != expected_id { 185 return Err(MlsError::InvalidEpoch); 186 } 187 } 188 189 self.pending_commit.inserts.push_back(epoch); 190 191 Ok(()) 192 } 193 194 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError>195 pub async fn write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError> { 196 let inserts = self 197 .pending_commit 198 .inserts 199 .iter() 200 .map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?))) 201 .collect::<Result<_, MlsError>>()?; 202 203 let updates = self 204 .pending_commit 205 .updates 206 .iter() 207 .map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?))) 208 .collect::<Result<_, MlsError>>()?; 209 210 let group_state = GroupState { 211 data: group_snapshot.mls_encode_to_vec()?, 212 id: group_snapshot.state.context.group_id, 213 }; 214 215 self.storage 216 .write(group_state, inserts, updates) 217 .await 218 .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?; 219 220 if let Some(ref key_package_ref) = self.pending_key_package_removal { 221 self.key_package_repo 222 .delete(key_package_ref) 223 .await 224 .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))?; 225 } 226 227 self.pending_commit.inserts.clear(); 228 self.pending_commit.updates.clear(); 229 230 Ok(()) 231 } 232 233 #[cfg(any(feature = "psk", feature = "private_message"))] find_pending(&self, epoch_id: u64) -> Option<usize>234 fn find_pending(&self, epoch_id: u64) -> Option<usize> { 235 self.pending_commit 236 .updates 237 .iter() 238 .position(|ep| ep.context.epoch == epoch_id) 239 } 240 } 241 242 #[cfg(test)] 243 mod tests { 244 use alloc::vec; 245 use mls_rs_codec::MlsEncode; 246 247 use crate::{ 248 client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION}, 249 group::{ 250 epoch::{test_utils::get_test_epoch_with_id, SenderDataSecret}, 251 test_utils::{random_bytes, test_member, TEST_GROUP}, 252 PskGroupId, ResumptionPSKUsage, 253 }, 254 storage_provider::in_memory::{InMemoryGroupStateStorage, InMemoryKeyPackageStorage}, 255 }; 256 257 use super::*; 258 test_group_state_repo( retention_limit: usize, ) -> GroupStateRepository<InMemoryGroupStateStorage, InMemoryKeyPackageStorage>259 fn test_group_state_repo( 260 retention_limit: usize, 261 ) -> GroupStateRepository<InMemoryGroupStateStorage, InMemoryKeyPackageStorage> { 262 GroupStateRepository::new( 263 TEST_GROUP.to_vec(), 264 InMemoryGroupStateStorage::new() 265 .with_max_epoch_retention(retention_limit) 266 .unwrap(), 267 InMemoryKeyPackageStorage::default(), 268 None, 269 ) 270 .unwrap() 271 } 272 test_epoch(epoch_id: u64) -> PriorEpoch273 fn test_epoch(epoch_id: u64) -> PriorEpoch { 274 get_test_epoch_with_id(TEST_GROUP.to_vec(), TEST_CIPHER_SUITE, epoch_id) 275 } 276 277 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] test_snapshot(epoch_id: u64) -> Snapshot278 async fn test_snapshot(epoch_id: u64) -> Snapshot { 279 crate::group::snapshot::test_utils::get_test_snapshot(TEST_CIPHER_SUITE, epoch_id).await 280 } 281 282 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_epoch_inserts()283 async fn test_epoch_inserts() { 284 let mut test_repo = test_group_state_repo(1); 285 let test_epoch = test_epoch(0); 286 287 test_repo.insert(test_epoch.clone()).await.unwrap(); 288 289 // Check the in-memory state 290 assert_eq!( 291 test_repo.pending_commit.inserts.back().unwrap(), 292 &test_epoch 293 ); 294 295 assert!(test_repo.pending_commit.updates.is_empty()); 296 297 #[cfg(feature = "std")] 298 assert!(test_repo.storage.inner.lock().unwrap().is_empty()); 299 #[cfg(not(feature = "std"))] 300 assert!(test_repo.storage.inner.lock().is_empty()); 301 302 let psk_id = ResumptionPsk { 303 psk_epoch: 0, 304 psk_group_id: PskGroupId(test_repo.group_id.clone()), 305 usage: ResumptionPSKUsage::Application, 306 }; 307 308 // Make sure you can recall an epoch sitting as a pending insert 309 let resumption = test_repo.resumption_secret(&psk_id).await.unwrap(); 310 let prior_epoch = test_repo.get_epoch_mut(0).await.unwrap().cloned(); 311 312 assert_eq!( 313 prior_epoch.clone().unwrap().secrets.resumption_secret, 314 resumption.unwrap() 315 ); 316 317 assert_eq!(prior_epoch.unwrap(), test_epoch); 318 319 // Write to the storage 320 let snapshot = test_snapshot(test_epoch.epoch_id()).await; 321 test_repo.write_to_storage(snapshot.clone()).await.unwrap(); 322 323 // Make sure the memory cache cleared 324 assert!(test_repo.pending_commit.inserts.is_empty()); 325 assert!(test_repo.pending_commit.updates.is_empty()); 326 327 // Make sure the storage was written 328 #[cfg(feature = "std")] 329 let storage = test_repo.storage.inner.lock().unwrap(); 330 #[cfg(not(feature = "std"))] 331 let storage = test_repo.storage.inner.lock(); 332 333 assert_eq!(storage.len(), 1); 334 335 let stored = storage.get(TEST_GROUP).unwrap(); 336 337 assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap()); 338 339 assert_eq!(stored.epoch_data.len(), 1); 340 341 assert_eq!( 342 stored.epoch_data.back().unwrap(), 343 &EpochRecord::new( 344 test_epoch.epoch_id(), 345 test_epoch.mls_encode_to_vec().unwrap() 346 ) 347 ); 348 } 349 350 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_updates()351 async fn test_updates() { 352 let mut test_repo = test_group_state_repo(2); 353 let test_epoch_0 = test_epoch(0); 354 355 test_repo.insert(test_epoch_0.clone()).await.unwrap(); 356 357 test_repo 358 .write_to_storage(test_snapshot(0).await) 359 .await 360 .unwrap(); 361 362 // Update the stored epoch 363 let to_update = test_repo.get_epoch_mut(0).await.unwrap().unwrap(); 364 assert_eq!(to_update, &test_epoch_0); 365 366 let new_sender_secret = random_bytes(32); 367 to_update.secrets.sender_data_secret = SenderDataSecret::from(new_sender_secret); 368 let to_update = to_update.clone(); 369 370 assert_eq!(test_repo.pending_commit.updates.len(), 1); 371 assert!(test_repo.pending_commit.inserts.is_empty()); 372 373 assert_eq!( 374 test_repo.pending_commit.updates.first().unwrap(), 375 &to_update 376 ); 377 378 // Make sure you can access an epoch pending update 379 let psk_id = ResumptionPsk { 380 psk_epoch: 0, 381 psk_group_id: PskGroupId(test_repo.group_id.clone()), 382 usage: ResumptionPSKUsage::Application, 383 }; 384 385 let owned = test_repo.resumption_secret(&psk_id).await.unwrap(); 386 assert_eq!(owned.as_ref(), Some(&to_update.secrets.resumption_secret)); 387 388 // Write the update to storage 389 let snapshot = test_snapshot(1).await; 390 test_repo.write_to_storage(snapshot.clone()).await.unwrap(); 391 392 assert!(test_repo.pending_commit.updates.is_empty()); 393 assert!(test_repo.pending_commit.inserts.is_empty()); 394 395 // Make sure the storage was written 396 #[cfg(feature = "std")] 397 let storage = test_repo.storage.inner.lock().unwrap(); 398 #[cfg(not(feature = "std"))] 399 let storage = test_repo.storage.inner.lock(); 400 401 assert_eq!(storage.len(), 1); 402 403 let stored = storage.get(TEST_GROUP).unwrap(); 404 405 assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap()); 406 407 assert_eq!(stored.epoch_data.len(), 1); 408 409 assert_eq!( 410 stored.epoch_data.back().unwrap(), 411 &EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap()) 412 ); 413 } 414 415 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_insert_and_update()416 async fn test_insert_and_update() { 417 let mut test_repo = test_group_state_repo(2); 418 let test_epoch_0 = test_epoch(0); 419 420 test_repo.insert(test_epoch_0).await.unwrap(); 421 422 test_repo 423 .write_to_storage(test_snapshot(0).await) 424 .await 425 .unwrap(); 426 427 // Update the stored epoch 428 let to_update = test_repo.get_epoch_mut(0).await.unwrap().unwrap(); 429 let new_sender_secret = random_bytes(32); 430 to_update.secrets.sender_data_secret = SenderDataSecret::from(new_sender_secret); 431 let to_update = to_update.clone(); 432 433 // Insert another epoch 434 let test_epoch_1 = test_epoch(1); 435 test_repo.insert(test_epoch_1.clone()).await.unwrap(); 436 437 test_repo 438 .write_to_storage(test_snapshot(1).await) 439 .await 440 .unwrap(); 441 442 assert!(test_repo.pending_commit.inserts.is_empty()); 443 assert!(test_repo.pending_commit.updates.is_empty()); 444 445 // Make sure the storage was written 446 #[cfg(feature = "std")] 447 let storage = test_repo.storage.inner.lock().unwrap(); 448 #[cfg(not(feature = "std"))] 449 let storage = test_repo.storage.inner.lock(); 450 451 assert_eq!(storage.len(), 1); 452 453 let stored = storage.get(TEST_GROUP).unwrap(); 454 455 assert_eq!(stored.epoch_data.len(), 2); 456 457 assert_eq!( 458 stored.epoch_data.front().unwrap(), 459 &EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap()) 460 ); 461 462 assert_eq!( 463 stored.epoch_data.back().unwrap(), 464 &EpochRecord::new( 465 test_epoch_1.epoch_id(), 466 test_epoch_1.mls_encode_to_vec().unwrap() 467 ) 468 ); 469 } 470 471 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_many_epochs_in_storage()472 async fn test_many_epochs_in_storage() { 473 let epochs = (0..10).map(test_epoch).collect::<Vec<_>>(); 474 475 let mut test_repo = test_group_state_repo(10); 476 477 for epoch in epochs.iter().cloned() { 478 test_repo.insert(epoch).await.unwrap() 479 } 480 481 test_repo 482 .write_to_storage(test_snapshot(9).await) 483 .await 484 .unwrap(); 485 486 for mut epoch in epochs { 487 let res = test_repo.get_epoch_mut(epoch.epoch_id()).await.unwrap(); 488 489 assert_eq!(res, Some(&mut epoch)); 490 } 491 } 492 493 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_stored_groups_list()494 async fn test_stored_groups_list() { 495 let mut test_repo = test_group_state_repo(2); 496 let test_epoch_0 = test_epoch(0); 497 498 test_repo.insert(test_epoch_0.clone()).await.unwrap(); 499 500 test_repo 501 .write_to_storage(test_snapshot(0).await) 502 .await 503 .unwrap(); 504 505 assert_eq!( 506 test_repo.storage.stored_groups(), 507 vec![test_epoch_0.context.group_id] 508 ) 509 } 510 511 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] reducing_retention_limit_takes_effect_on_epoch_access()512 async fn reducing_retention_limit_takes_effect_on_epoch_access() { 513 let mut repo = test_group_state_repo(1); 514 515 repo.insert(test_epoch(0)).await.unwrap(); 516 repo.insert(test_epoch(1)).await.unwrap(); 517 518 repo.write_to_storage(test_snapshot(0).await).await.unwrap(); 519 520 let mut repo = GroupStateRepository { 521 storage: repo.storage, 522 ..test_group_state_repo(1) 523 }; 524 525 let res = repo.get_epoch_mut(0).await.unwrap(); 526 527 assert!(res.is_none()); 528 } 529 530 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] in_memory_storage_obeys_retention_limit_after_saving()531 async fn in_memory_storage_obeys_retention_limit_after_saving() { 532 let mut repo = test_group_state_repo(1); 533 534 repo.insert(test_epoch(0)).await.unwrap(); 535 repo.write_to_storage(test_snapshot(0).await).await.unwrap(); 536 repo.insert(test_epoch(1)).await.unwrap(); 537 repo.write_to_storage(test_snapshot(1).await).await.unwrap(); 538 539 #[cfg(feature = "std")] 540 let lock = repo.storage.inner.lock().unwrap(); 541 #[cfg(not(feature = "std"))] 542 let lock = repo.storage.inner.lock(); 543 544 assert_eq!(lock.get(TEST_GROUP).unwrap().epoch_data.len(), 1); 545 } 546 547 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] used_key_package_is_deleted()548 async fn used_key_package_is_deleted() { 549 let key_package_repo = InMemoryKeyPackageStorage::default(); 550 551 let key_package = test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"member") 552 .await 553 .0; 554 555 let (id, data) = key_package.to_storage().unwrap(); 556 557 key_package_repo.insert(id, data); 558 559 let mut repo = GroupStateRepository::new( 560 TEST_GROUP.to_vec(), 561 InMemoryGroupStateStorage::new(), 562 key_package_repo, 563 Some(key_package.reference.clone()), 564 ) 565 .unwrap(); 566 567 repo.key_package_repo.get(&key_package.reference).unwrap(); 568 569 repo.write_to_storage(test_snapshot(4).await).await.unwrap(); 570 571 assert!(repo.key_package_repo.get(&key_package.reference).is_none()); 572 } 573 } 574