1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use mls_rs_core::{
6     error::IntoAnyError, identity::IdentityProvider, key_package::KeyPackageStorage,
7 };
8 
9 use crate::{
10     cipher_suite::CipherSuite,
11     client::MlsError,
12     extension::RatchetTreeExt,
13     key_package::KeyPackageGeneration,
14     protocol_version::ProtocolVersion,
15     signer::Signable,
16     tree_kem::{node::LeafIndex, tree_validator::TreeValidator, TreeKemPublic},
17     CipherSuiteProvider, CryptoProvider,
18 };
19 
20 #[cfg(feature = "by_ref_proposal")]
21 use crate::extension::ExternalSendersExt;
22 
23 use super::{
24     framing::Sender, message_signature::AuthenticatedContent,
25     transcript_hash::InterimTranscriptHash, ConfirmedTranscriptHash, EncryptedGroupSecrets,
26     ExportedTree, GroupInfo, GroupState,
27 };
28 
29 use super::message_processor::ProvisionalState;
30 
31 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
validate_group_info_common<C: CipherSuiteProvider>( msg_version: ProtocolVersion, group_info: &GroupInfo, tree: &TreeKemPublic, cs: &C, ) -> Result<(), MlsError>32 pub(crate) async fn validate_group_info_common<C: CipherSuiteProvider>(
33     msg_version: ProtocolVersion,
34     group_info: &GroupInfo,
35     tree: &TreeKemPublic,
36     cs: &C,
37 ) -> Result<(), MlsError> {
38     if msg_version != group_info.group_context.protocol_version {
39         return Err(MlsError::ProtocolVersionMismatch);
40     }
41 
42     if group_info.group_context.cipher_suite != cs.cipher_suite() {
43         return Err(MlsError::CipherSuiteMismatch);
44     }
45 
46     let sender_leaf = &tree.get_leaf_node(group_info.signer)?;
47 
48     group_info
49         .verify(cs, &sender_leaf.signing_identity.signature_key, &())
50         .await?;
51 
52     Ok(())
53 }
54 
55 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
validate_group_info_member<C: CipherSuiteProvider>( self_state: &GroupState, msg_version: ProtocolVersion, group_info: &GroupInfo, cs: &C, ) -> Result<(), MlsError>56 pub(crate) async fn validate_group_info_member<C: CipherSuiteProvider>(
57     self_state: &GroupState,
58     msg_version: ProtocolVersion,
59     group_info: &GroupInfo,
60     cs: &C,
61 ) -> Result<(), MlsError> {
62     validate_group_info_common(msg_version, group_info, &self_state.public_tree, cs).await?;
63 
64     let self_tree = ExportedTree::new_borrowed(&self_state.public_tree.nodes);
65 
66     if let Some(tree) = group_info.extensions.get_as::<RatchetTreeExt>()? {
67         (tree.tree_data == self_tree)
68             .then_some(())
69             .ok_or(MlsError::InvalidGroupInfo)?;
70     }
71 
72     (group_info.group_context == self_state.context
73         && group_info.confirmation_tag == self_state.confirmation_tag)
74         .then_some(())
75         .ok_or(MlsError::InvalidGroupInfo)?;
76 
77     Ok(())
78 }
79 
80 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
validate_group_info_joiner<C, I>( msg_version: ProtocolVersion, group_info: &GroupInfo, tree: Option<ExportedTree<'_>>, id_provider: &I, cs: &C, ) -> Result<TreeKemPublic, MlsError> where C: CipherSuiteProvider, I: IdentityProvider,81 pub(crate) async fn validate_group_info_joiner<C, I>(
82     msg_version: ProtocolVersion,
83     group_info: &GroupInfo,
84     tree: Option<ExportedTree<'_>>,
85     id_provider: &I,
86     cs: &C,
87 ) -> Result<TreeKemPublic, MlsError>
88 where
89     C: CipherSuiteProvider,
90     I: IdentityProvider,
91 {
92     let tree = match group_info.extensions.get_as::<RatchetTreeExt>()? {
93         Some(ext) => ext.tree_data,
94         None => tree.ok_or(MlsError::RatchetTreeNotFound)?,
95     };
96 
97     let context = &group_info.group_context;
98 
99     let mut tree =
100         TreeKemPublic::import_node_data(tree.into(), id_provider, &context.extensions).await?;
101 
102     // Verify the integrity of the ratchet tree
103     TreeValidator::new(cs, context, id_provider)
104         .validate(&mut tree)
105         .await?;
106 
107     #[cfg(feature = "by_ref_proposal")]
108     if let Some(ext_senders) = context.extensions.get_as::<ExternalSendersExt>()? {
109         // TODO do joiners verify group against current time??
110         ext_senders
111             .verify_all(id_provider, None, &context.extensions)
112             .await
113             .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
114     }
115 
116     validate_group_info_common(msg_version, group_info, &tree, cs).await?;
117 
118     Ok(tree)
119 }
120 
commit_sender( sender: &Sender, provisional_state: &ProvisionalState, ) -> Result<LeafIndex, MlsError>121 pub(crate) fn commit_sender(
122     sender: &Sender,
123     provisional_state: &ProvisionalState,
124 ) -> Result<LeafIndex, MlsError> {
125     match sender {
126         Sender::Member(index) => Ok(LeafIndex(*index)),
127         #[cfg(feature = "by_ref_proposal")]
128         Sender::External(_) => Err(MlsError::ExternalSenderCannotCommit),
129         #[cfg(feature = "by_ref_proposal")]
130         Sender::NewMemberProposal => Err(MlsError::ExpectedAddProposalForNewMemberProposal),
131         Sender::NewMemberCommit => provisional_state
132             .external_init_index
133             .ok_or(MlsError::ExternalCommitMissingExternalInit),
134     }
135 }
136 
137 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
transcript_hashes<P: CipherSuiteProvider>( cipher_suite_provider: &P, prev_interim_transcript_hash: &InterimTranscriptHash, content: &AuthenticatedContent, ) -> Result<(InterimTranscriptHash, ConfirmedTranscriptHash), MlsError>138 pub(super) async fn transcript_hashes<P: CipherSuiteProvider>(
139     cipher_suite_provider: &P,
140     prev_interim_transcript_hash: &InterimTranscriptHash,
141     content: &AuthenticatedContent,
142 ) -> Result<(InterimTranscriptHash, ConfirmedTranscriptHash), MlsError> {
143     let confirmed_transcript_hash = ConfirmedTranscriptHash::create(
144         cipher_suite_provider,
145         prev_interim_transcript_hash,
146         content,
147     )
148     .await?;
149 
150     let confirmation_tag = content
151         .auth
152         .confirmation_tag
153         .as_ref()
154         .ok_or(MlsError::InvalidConfirmationTag)?;
155 
156     let interim_transcript_hash = InterimTranscriptHash::create(
157         cipher_suite_provider,
158         &confirmed_transcript_hash,
159         confirmation_tag,
160     )
161     .await?;
162 
163     Ok((interim_transcript_hash, confirmed_transcript_hash))
164 }
165 
166 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
find_key_package_generation<'a, K: KeyPackageStorage>( key_package_repo: &K, secrets: &'a [EncryptedGroupSecrets], ) -> Result<(&'a EncryptedGroupSecrets, KeyPackageGeneration), MlsError>167 pub(crate) async fn find_key_package_generation<'a, K: KeyPackageStorage>(
168     key_package_repo: &K,
169     secrets: &'a [EncryptedGroupSecrets],
170 ) -> Result<(&'a EncryptedGroupSecrets, KeyPackageGeneration), MlsError> {
171     for secret in secrets {
172         if let Some(val) = key_package_repo
173             .get(&secret.new_member)
174             .await
175             .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))
176             .and_then(|maybe_data| {
177                 if let Some(data) = maybe_data {
178                     KeyPackageGeneration::from_storage(secret.new_member.to_vec(), data)
179                         .map(|kpg| Some((secret, kpg)))
180                 } else {
181                     Ok::<_, MlsError>(None)
182                 }
183             })?
184         {
185             return Ok(val);
186         }
187     }
188 
189     Err(MlsError::WelcomeKeyPackageNotFound)
190 }
191 
cipher_suite_provider<P>( crypto: P, cipher_suite: CipherSuite, ) -> Result<P::CipherSuiteProvider, MlsError> where P: CryptoProvider,192 pub(crate) fn cipher_suite_provider<P>(
193     crypto: P,
194     cipher_suite: CipherSuite,
195 ) -> Result<P::CipherSuiteProvider, MlsError>
196 where
197     P: CryptoProvider,
198 {
199     crypto
200         .cipher_suite_provider(cipher_suite)
201         .ok_or(MlsError::UnsupportedCipherSuite(cipher_suite))
202 }
203