1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec::Vec;
6 use mls_rs_core::{
7     crypto::CipherSuiteProvider,
8     error::IntoAnyError,
9     group::GroupStateStorage,
10     key_package::KeyPackageStorage,
11     psk::{ExternalPskId, PreSharedKey, PreSharedKeyStorage},
12 };
13 
14 use crate::{
15     client::MlsError,
16     group::{epoch::EpochSecrets, state_repo::GroupStateRepository, GroupContext},
17     psk::secret::PskSecret,
18 };
19 
20 use super::{secret::PskSecretInput, JustPreSharedKeyID, PreSharedKeyID, ResumptionPsk};
21 
22 pub(crate) struct PskResolver<'a, GS, K, PS>
23 where
24     GS: GroupStateStorage,
25     PS: PreSharedKeyStorage,
26     K: KeyPackageStorage,
27 {
28     pub group_context: Option<&'a GroupContext>,
29     pub current_epoch: Option<&'a EpochSecrets>,
30     pub prior_epochs: Option<&'a GroupStateRepository<GS, K>>,
31     pub psk_store: &'a PS,
32 }
33 
34 impl<GS: GroupStateStorage, K: KeyPackageStorage, PS: PreSharedKeyStorage>
35     PskResolver<'_, GS, K, PS>
36 {
37     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
resolve_resumption(&self, psk_id: &ResumptionPsk) -> Result<PreSharedKey, MlsError>38     async fn resolve_resumption(&self, psk_id: &ResumptionPsk) -> Result<PreSharedKey, MlsError> {
39         if let Some(ctx) = self.group_context {
40             if ctx.epoch == psk_id.psk_epoch && ctx.group_id == psk_id.psk_group_id.0 {
41                 let epoch = self.current_epoch.ok_or(MlsError::OldGroupStateNotFound)?;
42                 return Ok(epoch.resumption_secret.clone());
43             }
44         }
45 
46         #[cfg(feature = "prior_epoch")]
47         if let Some(eps) = self.prior_epochs {
48             if let Some(psk) = eps.resumption_secret(psk_id).await? {
49                 return Ok(psk);
50             }
51         }
52 
53         Err(MlsError::OldGroupStateNotFound)
54     }
55 
56     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
resolve_external(&self, psk_id: &ExternalPskId) -> Result<PreSharedKey, MlsError>57     async fn resolve_external(&self, psk_id: &ExternalPskId) -> Result<PreSharedKey, MlsError> {
58         self.psk_store
59             .get(psk_id)
60             .await
61             .map_err(|e| MlsError::PskStoreError(e.into_any_error()))?
62             .ok_or(MlsError::MissingRequiredPsk)
63     }
64 
65     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
resolve(&self, id: &[PreSharedKeyID]) -> Result<Vec<PskSecretInput>, MlsError>66     async fn resolve(&self, id: &[PreSharedKeyID]) -> Result<Vec<PskSecretInput>, MlsError> {
67         let mut secret_inputs = Vec::new();
68 
69         for id in id {
70             let psk = match &id.key_id {
71                 JustPreSharedKeyID::External(external) => self.resolve_external(external).await,
72                 JustPreSharedKeyID::Resumption(resumption) => {
73                     self.resolve_resumption(resumption).await
74                 }
75             }?;
76 
77             secret_inputs.push(PskSecretInput {
78                 id: id.clone(),
79                 psk,
80             })
81         }
82 
83         Ok(secret_inputs)
84     }
85 
86     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
resolve_to_secret<P: CipherSuiteProvider>( &self, id: &[PreSharedKeyID], cipher_suite_provider: &P, ) -> Result<PskSecret, MlsError>87     pub async fn resolve_to_secret<P: CipherSuiteProvider>(
88         &self,
89         id: &[PreSharedKeyID],
90         cipher_suite_provider: &P,
91     ) -> Result<PskSecret, MlsError> {
92         let psk = self.resolve(id).await?;
93         PskSecret::calculate(&psk, cipher_suite_provider).await
94     }
95 }
96