1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use alloc::vec::Vec; 6 use mls_rs_core::{ 7 crypto::CipherSuiteProvider, 8 error::IntoAnyError, 9 group::GroupStateStorage, 10 key_package::KeyPackageStorage, 11 psk::{ExternalPskId, PreSharedKey, PreSharedKeyStorage}, 12 }; 13 14 use crate::{ 15 client::MlsError, 16 group::{epoch::EpochSecrets, state_repo::GroupStateRepository, GroupContext}, 17 psk::secret::PskSecret, 18 }; 19 20 use super::{secret::PskSecretInput, JustPreSharedKeyID, PreSharedKeyID, ResumptionPsk}; 21 22 pub(crate) struct PskResolver<'a, GS, K, PS> 23 where 24 GS: GroupStateStorage, 25 PS: PreSharedKeyStorage, 26 K: KeyPackageStorage, 27 { 28 pub group_context: Option<&'a GroupContext>, 29 pub current_epoch: Option<&'a EpochSecrets>, 30 pub prior_epochs: Option<&'a GroupStateRepository<GS, K>>, 31 pub psk_store: &'a PS, 32 } 33 34 impl<GS: GroupStateStorage, K: KeyPackageStorage, PS: PreSharedKeyStorage> 35 PskResolver<'_, GS, K, PS> 36 { 37 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] resolve_resumption(&self, psk_id: &ResumptionPsk) -> Result<PreSharedKey, MlsError>38 async fn resolve_resumption(&self, psk_id: &ResumptionPsk) -> Result<PreSharedKey, MlsError> { 39 if let Some(ctx) = self.group_context { 40 if ctx.epoch == psk_id.psk_epoch && ctx.group_id == psk_id.psk_group_id.0 { 41 let epoch = self.current_epoch.ok_or(MlsError::OldGroupStateNotFound)?; 42 return Ok(epoch.resumption_secret.clone()); 43 } 44 } 45 46 #[cfg(feature = "prior_epoch")] 47 if let Some(eps) = self.prior_epochs { 48 if let Some(psk) = eps.resumption_secret(psk_id).await? { 49 return Ok(psk); 50 } 51 } 52 53 Err(MlsError::OldGroupStateNotFound) 54 } 55 56 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] resolve_external(&self, psk_id: &ExternalPskId) -> Result<PreSharedKey, MlsError>57 async fn resolve_external(&self, psk_id: &ExternalPskId) -> Result<PreSharedKey, MlsError> { 58 self.psk_store 59 .get(psk_id) 60 .await 61 .map_err(|e| MlsError::PskStoreError(e.into_any_error()))? 62 .ok_or(MlsError::MissingRequiredPsk) 63 } 64 65 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] resolve(&self, id: &[PreSharedKeyID]) -> Result<Vec<PskSecretInput>, MlsError>66 async fn resolve(&self, id: &[PreSharedKeyID]) -> Result<Vec<PskSecretInput>, MlsError> { 67 let mut secret_inputs = Vec::new(); 68 69 for id in id { 70 let psk = match &id.key_id { 71 JustPreSharedKeyID::External(external) => self.resolve_external(external).await, 72 JustPreSharedKeyID::Resumption(resumption) => { 73 self.resolve_resumption(resumption).await 74 } 75 }?; 76 77 secret_inputs.push(PskSecretInput { 78 id: id.clone(), 79 psk, 80 }) 81 } 82 83 Ok(secret_inputs) 84 } 85 86 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] resolve_to_secret<P: CipherSuiteProvider>( &self, id: &[PreSharedKeyID], cipher_suite_provider: &P, ) -> Result<PskSecret, MlsError>87 pub async fn resolve_to_secret<P: CipherSuiteProvider>( 88 &self, 89 id: &[PreSharedKeyID], 90 cipher_suite_provider: &P, 91 ) -> Result<PskSecret, MlsError> { 92 let psk = self.resolve(id).await?; 93 PskSecret::calculate(&psk, cipher_suite_provider).await 94 } 95 } 96