1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use super::{
6     commit_sender,
7     confirmation_tag::ConfirmationTag,
8     framing::{
9         ApplicationData, Content, ContentType, MlsMessage, MlsMessagePayload, PublicMessage, Sender,
10     },
11     message_signature::AuthenticatedContent,
12     mls_rules::{CommitDirection, MlsRules},
13     proposal_filter::ProposalBundle,
14     state::GroupState,
15     transcript_hash::InterimTranscriptHash,
16     transcript_hashes, validate_group_info_member, GroupContext, GroupInfo, Welcome,
17 };
18 use crate::{
19     client::MlsError,
20     key_package::validate_key_package_properties,
21     time::MlsTime,
22     tree_kem::{
23         leaf_node_validator::{LeafNodeValidator, ValidationContext},
24         node::LeafIndex,
25         path_secret::PathSecret,
26         validate_update_path, TreeKemPrivate, TreeKemPublic, ValidatedUpdatePath,
27     },
28     CipherSuiteProvider, KeyPackage,
29 };
30 #[cfg(mls_build_async)]
31 use alloc::boxed::Box;
32 use alloc::vec::Vec;
33 use core::fmt::{self, Debug};
34 use mls_rs_core::{
35     identity::IdentityProvider, protocol_version::ProtocolVersion, psk::PreSharedKeyStorage,
36 };
37 
38 #[cfg(feature = "by_ref_proposal")]
39 use super::proposal_ref::ProposalRef;
40 
41 #[cfg(not(feature = "by_ref_proposal"))]
42 use crate::group::proposal_cache::resolve_for_commit;
43 
44 #[cfg(feature = "by_ref_proposal")]
45 use super::proposal::Proposal;
46 
47 #[cfg(feature = "custom_proposal")]
48 use super::proposal_filter::ProposalInfo;
49 
50 #[cfg(feature = "state_update")]
51 use mls_rs_core::{
52     crypto::CipherSuite,
53     group::{MemberUpdate, RosterUpdate},
54 };
55 
56 #[cfg(all(feature = "state_update", feature = "psk"))]
57 use mls_rs_core::psk::ExternalPskId;
58 
59 #[cfg(feature = "state_update")]
60 use crate::tree_kem::UpdatePath;
61 
62 #[cfg(feature = "state_update")]
63 use super::{member_from_key_package, member_from_leaf_node};
64 
65 #[cfg(all(feature = "state_update", feature = "custom_proposal"))]
66 use super::proposal::CustomProposal;
67 
68 #[cfg(feature = "private_message")]
69 use crate::group::framing::PrivateMessage;
70 
71 #[cfg(feature = "by_ref_proposal")]
72 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
73 
/// Candidate group state assembled while a commit is being applied.
///
/// The commit-processing code in this file reads the applied proposals from
/// it, refreshes the tree and context hashes on it, and finally hands it to
/// the key schedule update.
#[derive(Debug)]
pub(crate) struct ProvisionalState {
    /// Public ratchet tree of the candidate state; its hashes are recomputed
    /// while a commit is processed.
    pub(crate) public_tree: TreeKemPublic,
    /// Proposals applied by the commit, grouped by proposal kind.
    pub(crate) applied_proposals: ProposalBundle,
    /// Group context of the candidate state; its `epoch` is the value reported
    /// by `StateUpdate::new_epoch`.
    pub(crate) group_context: GroupContext,
    /// NOTE(review): presumably the leaf index assigned to an external
    /// committer, if any — confirm against the code that populates it.
    pub(crate) external_init_index: Option<LeafIndex>,
    /// Leaf indexes for added key packages, paired positionally with
    /// `applied_proposals.additions` when the roster update is built.
    pub(crate) indexes_of_added_kpkgs: Vec<LeafIndex>,
    /// Proposals that were not committed to; copied verbatim into
    /// `StateUpdate::unused_proposals`.
    #[cfg(feature = "by_ref_proposal")]
    pub(crate) unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
}
84 
// From RFC 9420, Section 12.4 (Commit):
//
// By default, the path field of a Commit MUST be populated. The path field MAY be omitted if
// (a) it covers at least one proposal and (b) none of the proposals covered by the Commit are
// of "path required" types. A proposal type requires a path if it cannot change the group
// membership in a way that requires the forward secrecy and post-compromise security guarantees
// that an UpdatePath provides. The only proposal types defined in this document that do not
// require a path are:
//
// - add
// - psk
// - reinit
path_update_required(proposals: &ProposalBundle) -> bool95 pub(crate) fn path_update_required(proposals: &ProposalBundle) -> bool {
96     let res = proposals.external_init_proposals().first().is_some();
97 
98     #[cfg(feature = "by_ref_proposal")]
99     let res = res || !proposals.update_proposals().is_empty();
100 
101     res || proposals.length() == 0
102         || proposals.group_context_extensions_proposal().is_some()
103         || !proposals.remove_proposals().is_empty()
104 }
105 
/// Representation of changes made by a [commit](crate::Group::commit).
#[cfg(feature = "state_update")]
#[derive(Clone, Debug, PartialEq)]
pub struct StateUpdate {
    /// Members added, removed and updated by the commit.
    pub(crate) roster_update: RosterUpdate,
    /// External pre-shared key identifiers referenced by applied PSK proposals.
    #[cfg(feature = "psk")]
    pub(crate) added_psks: Vec<ExternalPskId>,
    /// Cipher suite of the first applied ReInit proposal, if there is one.
    pub(crate) pending_reinit: Option<CipherSuite>,
    /// Starts out `true`; commit processing clears it when processing cannot
    /// continue for this member or when a ReInit proposal was applied.
    pub(crate) active: bool,
    /// Epoch taken from the provisional group context.
    pub(crate) epoch: u64,
    /// Custom proposals that were applied by the commit.
    #[cfg(feature = "custom_proposal")]
    pub(crate) custom_proposals: Vec<ProposalInfo<CustomProposal>>,
    /// Proposals that were received in the prior epoch but not committed to.
    #[cfg(feature = "by_ref_proposal")]
    pub(crate) unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
}
121 
/// Empty stand-in for the commit state description, used when the
/// `state_update` feature is disabled so that `CommitMessageDescription`
/// keeps the same shape in both configurations.
#[cfg(not(feature = "state_update"))]
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq)]
pub struct StateUpdate {}
126 
#[cfg(feature = "state_update")]
impl StateUpdate {
    /// Roster changes (additions, removals, updates) caused by the commit.
    pub fn roster_update(&self) -> &RosterUpdate {
        let Self { roster_update, .. } = self;
        roster_update
    }

    /// External pre-shared keys that were added to the group by the commit.
    #[cfg(feature = "psk")]
    pub fn added_psks(&self) -> &[ExternalPskId] {
        self.added_psks.as_slice()
    }

    /// Whether the group is now waiting to be reinitialized because a
    /// [`ReInit`](crate::group::proposal::Proposal::ReInit) proposal was
    /// received.
    pub fn is_pending_reinit(&self) -> bool {
        matches!(self.pending_reinit, Some(_))
    }

    /// Whether the group is still active. This is false if the member
    /// processing the commit has been removed from the group.
    pub fn is_active(&self) -> bool {
        let Self { active, .. } = self;
        *active
    }

    /// Epoch of the group state produced by the commit.
    pub fn new_epoch(&self) -> u64 {
        let Self { epoch, .. } = self;
        *epoch
    }

    /// Custom proposals the commit applied.
    #[cfg(feature = "custom_proposal")]
    pub fn custom_proposals(&self) -> &[ProposalInfo<CustomProposal>] {
        self.custom_proposals.as_slice()
    }

    /// Proposals received in the prior epoch that the commit did not cover.
    #[cfg(feature = "by_ref_proposal")]
    pub fn unused_proposals(&self) -> &[crate::mls_rules::ProposalInfo<Proposal>] {
        self.unused_proposals.as_slice()
    }

    /// Cipher suite requested by the pending reinitialization, or `None` when
    /// no reinitialization is pending.
    pub fn pending_reinit_ciphersuite(&self) -> Option<CipherSuite> {
        let Self { pending_reinit, .. } = self;
        *pending_reinit
    }
}
174 
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Debug, Clone)]
#[allow(clippy::large_enum_variant)]
/// An event generated as a result of processing a message for a group with
/// [`Group::process_incoming_message`](crate::group::Group::process_incoming_message).
///
/// Every payload type has a conversion into this enum: `From` for most of
/// them, and `TryFrom` for application messages.
pub enum ReceivedMessage {
    /// An application message was decrypted.
    ApplicationMessage(ApplicationMessageDescription),
    /// A new commit was processed creating a new group state.
    Commit(CommitMessageDescription),
    /// A proposal was received.
    Proposal(ProposalMessageDescription),
    /// Validated GroupInfo object
    GroupInfo(GroupInfo),
    /// Validated welcome message. The welcome itself is not retained: the
    /// conversion from `Welcome` discards its argument.
    Welcome,
    /// Validated key package
    KeyPackage(KeyPackage),
}
197 
198 impl TryFrom<ApplicationMessageDescription> for ReceivedMessage {
199     type Error = MlsError;
200 
try_from(value: ApplicationMessageDescription) -> Result<Self, Self::Error>201     fn try_from(value: ApplicationMessageDescription) -> Result<Self, Self::Error> {
202         Ok(ReceivedMessage::ApplicationMessage(value))
203     }
204 }
205 
206 impl From<CommitMessageDescription> for ReceivedMessage {
from(value: CommitMessageDescription) -> Self207     fn from(value: CommitMessageDescription) -> Self {
208         ReceivedMessage::Commit(value)
209     }
210 }
211 
212 impl From<ProposalMessageDescription> for ReceivedMessage {
from(value: ProposalMessageDescription) -> Self213     fn from(value: ProposalMessageDescription) -> Self {
214         ReceivedMessage::Proposal(value)
215     }
216 }
217 
218 impl From<GroupInfo> for ReceivedMessage {
from(value: GroupInfo) -> Self219     fn from(value: GroupInfo) -> Self {
220         ReceivedMessage::GroupInfo(value)
221     }
222 }
223 
224 impl From<Welcome> for ReceivedMessage {
from(_: Welcome) -> Self225     fn from(_: Welcome) -> Self {
226         ReceivedMessage::Welcome
227     }
228 }
229 
230 impl From<KeyPackage> for ReceivedMessage {
from(value: KeyPackage) -> Self231     fn from(value: KeyPackage) -> Self {
232         ReceivedMessage::KeyPackage(value)
233     }
234 }
235 
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, PartialEq, Eq)]
/// Description of a MLS application message.
pub struct ApplicationMessageDescription {
    /// Index of the sending member in the group state (taken from the
    /// message's `Sender::Member` value).
    pub sender_index: u32,
    /// Received application data. Private; read it through `data()`.
    data: ApplicationData,
    /// Plaintext authenticated data in the received MLS packet.
    pub authenticated_data: Vec<u8>,
}
250 
251 impl Debug for ApplicationMessageDescription {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result252     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
253         f.debug_struct("ApplicationMessageDescription")
254             .field("sender_index", &self.sender_index)
255             .field("data", &self.data)
256             .field(
257                 "authenticated_data",
258                 &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
259             )
260             .finish()
261     }
262 }
263 
264 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
265 impl ApplicationMessageDescription {
data(&self) -> &[u8]266     pub fn data(&self) -> &[u8] {
267         self.data.as_bytes()
268     }
269 }
270 
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, PartialEq)]
#[non_exhaustive]
/// Description of a processed MLS commit message.
pub struct CommitMessageDescription {
    /// True if this is the result of an external commit, i.e. the message
    /// sender was `Sender::NewMemberCommit`.
    pub is_external: bool,
    /// The index in the group state of the member who performed this commit.
    pub committer: u32,
    /// A full description of group state changes as a result of this commit.
    /// This is an empty struct when the `state_update` feature is disabled.
    pub state_update: StateUpdate,
    /// Plaintext authenticated data in the received MLS packet.
    pub authenticated_data: Vec<u8>,
}
288 
289 impl Debug for CommitMessageDescription {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result290     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
291         f.debug_struct("CommitMessageDescription")
292             .field("is_external", &self.is_external)
293             .field("committer", &self.committer)
294             .field("state_update", &self.state_update)
295             .field(
296                 "authenticated_data",
297                 &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
298             )
299             .finish()
300     }
301 }
302 
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Proposal sender type.
///
/// This is the subset of message senders that may author a proposal; the
/// conversion from `Sender` rejects `Sender::NewMemberCommit` with
/// `MlsError::InvalidSender`.
pub enum ProposalSender {
    /// A current member of the group by index in the group state.
    Member(u32),
    /// An external entity by index within an
    /// [`ExternalSendersExt`](crate::extension::built_in::ExternalSendersExt).
    External(u32),
    /// A new member proposing their addition to the group.
    NewMember,
}
314 
315 impl TryFrom<Sender> for ProposalSender {
316     type Error = MlsError;
317 
try_from(value: Sender) -> Result<Self, Self::Error>318     fn try_from(value: Sender) -> Result<Self, Self::Error> {
319         match value {
320             Sender::Member(index) => Ok(Self::Member(index)),
321             #[cfg(feature = "by_ref_proposal")]
322             Sender::External(index) => Ok(Self::External(index)),
323             #[cfg(feature = "by_ref_proposal")]
324             Sender::NewMemberProposal => Ok(Self::NewMember),
325             Sender::NewMemberCommit => Err(MlsError::InvalidSender),
326         }
327     }
328 }
329 
#[cfg(feature = "by_ref_proposal")]
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone)]
#[non_exhaustive]
/// Description of a processed MLS proposal message.
///
/// Only available in this form with the `by_ref_proposal` feature; without
/// it the type is an empty placeholder.
pub struct ProposalMessageDescription {
    /// Sender of the proposal.
    pub sender: ProposalSender,
    /// Proposal content.
    pub proposal: Proposal,
    /// Plaintext authenticated data in the received MLS packet.
    pub authenticated_data: Vec<u8>,
    /// Proposal reference, derived from the authenticated content of the
    /// message that carried the proposal.
    pub proposal_ref: ProposalRef,
}
348 
#[cfg(feature = "by_ref_proposal")]
impl Debug for ProposalMessageDescription {
    /// Formats the description as a debug struct, rendering
    /// `authenticated_data` through `mls_rs_core::debug::pretty_bytes`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let authenticated_data = mls_rs_core::debug::pretty_bytes(&self.authenticated_data);

        let mut out = f.debug_struct("ProposalMessageDescription");
        out.field("sender", &self.sender);
        out.field("proposal", &self.proposal);
        out.field("authenticated_data", &authenticated_data);
        out.field("proposal_ref", &self.proposal_ref);
        out.finish()
    }
}
363 
/// Serializable record of a received proposal: the proposal itself, its
/// reference and its sender.
///
/// Built by `ProposalMessageDescription::cached_proposal` and converted
/// to and from bytes with `to_bytes` / `from_bytes`.
// NOTE(review): the codec impls are derived, so the field order below
// presumably determines the encoded layout — confirm before reordering.
#[cfg(feature = "by_ref_proposal")]
#[derive(MlsSize, MlsEncode, MlsDecode)]
pub struct CachedProposal {
    /// The proposal content.
    pub(crate) proposal: Proposal,
    /// Reference identifying the proposal.
    pub(crate) proposal_ref: ProposalRef,
    /// Original message sender of the proposal.
    pub(crate) sender: Sender,
}
371 
#[cfg(feature = "by_ref_proposal")]
impl CachedProposal {
    /// Deserialize the proposal from its MLS encoding.
    ///
    /// # Errors
    ///
    /// Returns the decoding failure converted into an `MlsError`.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
        let mut remaining = bytes;
        let decoded = Self::mls_decode(&mut remaining)?;
        Ok(decoded)
    }

    /// Serialize the proposal into its MLS encoding.
    ///
    /// # Errors
    ///
    /// Returns the encoding failure converted into an `MlsError`.
    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
        let encoded = self.mls_encode_to_vec()?;
        Ok(encoded)
    }
}
384 
#[cfg(feature = "by_ref_proposal")]
impl ProposalMessageDescription {
    /// Converts this description into a [`CachedProposal`], keeping the
    /// proposal, its reference and its sender. The authenticated data is
    /// not carried over.
    pub fn cached_proposal(self) -> CachedProposal {
        let Self {
            sender,
            proposal,
            proposal_ref,
            ..
        } = self;

        CachedProposal {
            proposal,
            proposal_ref,
            // Translate back to the message-level sender representation.
            sender: match sender {
                ProposalSender::Member(i) => Sender::Member(i),
                ProposalSender::External(i) => Sender::External(i),
                ProposalSender::NewMember => Sender::NewMemberProposal,
            },
        }
    }

    /// Returns a copy of the proposal reference as a byte vector.
    pub fn proposal_ref(&self) -> Vec<u8> {
        let Self { proposal_ref, .. } = self;
        proposal_ref.to_vec()
    }
}
405 
#[cfg(not(feature = "by_ref_proposal"))]
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Debug, Clone)]
/// Description of a processed MLS proposal message.
///
/// This is the empty placeholder compiled when the `by_ref_proposal`
/// feature is disabled; it carries no data.
pub struct ProposalMessageDescription {}
414 
/// Intermediate result of the first stage of message processing: either a
/// finished event or authenticated content that still has to be processed.
#[allow(clippy::large_enum_variant)]
pub(crate) enum EventOrContent<E> {
    /// A finished event; `process_event_or_content` returns it unchanged.
    // Dead-code warnings are silenced because this variant may go
    // unconstructed unless both `private_message` and `external_client`
    // are enabled.
    #[cfg_attr(
        not(all(feature = "private_message", feature = "external_client")),
        allow(dead_code)
    )]
    Event(E),
    /// Authenticated content that is handed to `process_auth_content`.
    Content(AuthenticatedContent),
}
424 
425 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
426 #[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
427 #[cfg_attr(
428     all(not(target_arch = "wasm32"), mls_build_async),
429     maybe_async::must_be_async
430 )]
431 pub(crate) trait MessageProcessor: Send + Sync {
432     type OutputType: TryFrom<ApplicationMessageDescription, Error = MlsError>
433         + From<CommitMessageDescription>
434         + From<ProposalMessageDescription>
435         + From<GroupInfo>
436         + From<Welcome>
437         + From<KeyPackage>
438         + Send;
439 
440     type MlsRules: MlsRules;
441     type IdentityProvider: IdentityProvider;
442     type CipherSuiteProvider: CipherSuiteProvider;
443     type PreSharedKeyStorage: PreSharedKeyStorage;
444 
process_incoming_message( &mut self, message: MlsMessage, #[cfg(feature = "by_ref_proposal")] cache_proposal: bool, ) -> Result<Self::OutputType, MlsError>445     async fn process_incoming_message(
446         &mut self,
447         message: MlsMessage,
448         #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
449     ) -> Result<Self::OutputType, MlsError> {
450         self.process_incoming_message_with_time(
451             message,
452             #[cfg(feature = "by_ref_proposal")]
453             cache_proposal,
454             None,
455         )
456         .await
457     }
458 
process_incoming_message_with_time( &mut self, message: MlsMessage, #[cfg(feature = "by_ref_proposal")] cache_proposal: bool, time_sent: Option<MlsTime>, ) -> Result<Self::OutputType, MlsError>459     async fn process_incoming_message_with_time(
460         &mut self,
461         message: MlsMessage,
462         #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
463         time_sent: Option<MlsTime>,
464     ) -> Result<Self::OutputType, MlsError> {
465         let event_or_content = self.get_event_from_incoming_message(message).await?;
466 
467         self.process_event_or_content(
468             event_or_content,
469             #[cfg(feature = "by_ref_proposal")]
470             cache_proposal,
471             time_sent,
472         )
473         .await
474     }
475 
get_event_from_incoming_message( &mut self, message: MlsMessage, ) -> Result<EventOrContent<Self::OutputType>, MlsError>476     async fn get_event_from_incoming_message(
477         &mut self,
478         message: MlsMessage,
479     ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
480         self.check_metadata(&message)?;
481 
482         match message.payload {
483             MlsMessagePayload::Plain(plaintext) => {
484                 self.verify_plaintext_authentication(plaintext).await
485             }
486             #[cfg(feature = "private_message")]
487             MlsMessagePayload::Cipher(cipher_text) => self.process_ciphertext(&cipher_text).await,
488             MlsMessagePayload::GroupInfo(group_info) => {
489                 validate_group_info_member(
490                     self.group_state(),
491                     message.version,
492                     &group_info,
493                     self.cipher_suite_provider(),
494                 )
495                 .await?;
496 
497                 Ok(EventOrContent::Event(group_info.into()))
498             }
499             MlsMessagePayload::Welcome(welcome) => {
500                 self.validate_welcome(&welcome, message.version)?;
501 
502                 Ok(EventOrContent::Event(welcome.into()))
503             }
504             MlsMessagePayload::KeyPackage(key_package) => {
505                 self.validate_key_package(&key_package, message.version)
506                     .await?;
507 
508                 Ok(EventOrContent::Event(key_package.into()))
509             }
510         }
511     }
512 
process_event_or_content( &mut self, event_or_content: EventOrContent<Self::OutputType>, #[cfg(feature = "by_ref_proposal")] cache_proposal: bool, time_sent: Option<MlsTime>, ) -> Result<Self::OutputType, MlsError>513     async fn process_event_or_content(
514         &mut self,
515         event_or_content: EventOrContent<Self::OutputType>,
516         #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
517         time_sent: Option<MlsTime>,
518     ) -> Result<Self::OutputType, MlsError> {
519         let msg = match event_or_content {
520             EventOrContent::Event(event) => event,
521             EventOrContent::Content(content) => {
522                 self.process_auth_content(
523                     content,
524                     #[cfg(feature = "by_ref_proposal")]
525                     cache_proposal,
526                     time_sent,
527                 )
528                 .await?
529             }
530         };
531 
532         Ok(msg)
533     }
534 
process_auth_content( &mut self, auth_content: AuthenticatedContent, #[cfg(feature = "by_ref_proposal")] cache_proposal: bool, time_sent: Option<MlsTime>, ) -> Result<Self::OutputType, MlsError>535     async fn process_auth_content(
536         &mut self,
537         auth_content: AuthenticatedContent,
538         #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
539         time_sent: Option<MlsTime>,
540     ) -> Result<Self::OutputType, MlsError> {
541         let event = match auth_content.content.content {
542             #[cfg(feature = "private_message")]
543             Content::Application(data) => {
544                 let authenticated_data = auth_content.content.authenticated_data;
545                 let sender = auth_content.content.sender;
546 
547                 self.process_application_message(data, sender, authenticated_data)
548                     .and_then(Self::OutputType::try_from)
549             }
550             Content::Commit(_) => self
551                 .process_commit(auth_content, time_sent)
552                 .await
553                 .map(Self::OutputType::from),
554             #[cfg(feature = "by_ref_proposal")]
555             Content::Proposal(ref proposal) => self
556                 .process_proposal(&auth_content, proposal, cache_proposal)
557                 .await
558                 .map(Self::OutputType::from),
559         }?;
560 
561         Ok(event)
562     }
563 
564     #[cfg(feature = "private_message")]
process_application_message( &self, data: ApplicationData, sender: Sender, authenticated_data: Vec<u8>, ) -> Result<ApplicationMessageDescription, MlsError>565     fn process_application_message(
566         &self,
567         data: ApplicationData,
568         sender: Sender,
569         authenticated_data: Vec<u8>,
570     ) -> Result<ApplicationMessageDescription, MlsError> {
571         let Sender::Member(sender_index) = sender else {
572             return Err(MlsError::InvalidSender);
573         };
574 
575         Ok(ApplicationMessageDescription {
576             authenticated_data,
577             sender_index,
578             data,
579         })
580     }
581 
582     #[cfg(feature = "by_ref_proposal")]
583     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
process_proposal( &mut self, auth_content: &AuthenticatedContent, proposal: &Proposal, cache_proposal: bool, ) -> Result<ProposalMessageDescription, MlsError>584     async fn process_proposal(
585         &mut self,
586         auth_content: &AuthenticatedContent,
587         proposal: &Proposal,
588         cache_proposal: bool,
589     ) -> Result<ProposalMessageDescription, MlsError> {
590         let proposal_ref =
591             ProposalRef::from_content(self.cipher_suite_provider(), auth_content).await?;
592 
593         let group_state = self.group_state_mut();
594 
595         if cache_proposal {
596             let proposal_ref = proposal_ref.clone();
597 
598             group_state.proposals.insert(
599                 proposal_ref.clone(),
600                 proposal.clone(),
601                 auth_content.content.sender,
602             );
603         }
604 
605         Ok(ProposalMessageDescription {
606             authenticated_data: auth_content.content.authenticated_data.clone(),
607             proposal: proposal.clone(),
608             sender: auth_content.content.sender.try_into()?,
609             proposal_ref,
610         })
611     }
612 
613     #[cfg(feature = "state_update")]
make_state_update( &self, provisional: &ProvisionalState, path: Option<&UpdatePath>, sender: LeafIndex, ) -> Result<StateUpdate, MlsError>614     async fn make_state_update(
615         &self,
616         provisional: &ProvisionalState,
617         path: Option<&UpdatePath>,
618         sender: LeafIndex,
619     ) -> Result<StateUpdate, MlsError> {
620         let added = provisional
621             .applied_proposals
622             .additions
623             .iter()
624             .zip(provisional.indexes_of_added_kpkgs.iter())
625             .map(|(p, index)| member_from_key_package(&p.proposal.key_package, *index))
626             .collect::<Vec<_>>();
627 
628         let mut added = added;
629 
630         let old_tree = &self.group_state().public_tree;
631 
632         let removed = provisional
633             .applied_proposals
634             .removals
635             .iter()
636             .map(|p| {
637                 let index = p.proposal.to_remove;
638                 let node = old_tree.nodes.borrow_as_leaf(index)?;
639                 Ok(member_from_leaf_node(node, index))
640             })
641             .collect::<Result<_, MlsError>>()?;
642 
643         #[cfg(feature = "by_ref_proposal")]
644         let mut updated = provisional
645             .applied_proposals
646             .update_senders
647             .iter()
648             .map(|index| {
649                 let prior = old_tree
650                     .get_leaf_node(*index)
651                     .map(|n| member_from_leaf_node(n, *index))?;
652 
653                 let new = provisional
654                     .public_tree
655                     .get_leaf_node(*index)
656                     .map(|n| member_from_leaf_node(n, *index))?;
657 
658                 Ok::<_, MlsError>(MemberUpdate::new(prior, new))
659             })
660             .collect::<Result<Vec<_>, _>>()?;
661 
662         #[cfg(not(feature = "by_ref_proposal"))]
663         let mut updated = Vec::new();
664 
665         if let Some(path) = path {
666             if !provisional
667                 .applied_proposals
668                 .external_initializations
669                 .is_empty()
670             {
671                 added.push(member_from_leaf_node(&path.leaf_node, sender))
672             } else {
673                 let prior = old_tree
674                     .get_leaf_node(sender)
675                     .map(|n| member_from_leaf_node(n, sender))?;
676 
677                 let new = member_from_leaf_node(&path.leaf_node, sender);
678 
679                 updated.push(MemberUpdate::new(prior, new))
680             }
681         }
682 
683         #[cfg(feature = "psk")]
684         let psks = provisional
685             .applied_proposals
686             .psks
687             .iter()
688             .filter_map(|psk| psk.proposal.external_psk_id().cloned())
689             .collect::<Vec<_>>();
690 
691         let roster_update = RosterUpdate::new(added, removed, updated);
692 
693         let update = StateUpdate {
694             roster_update,
695             #[cfg(feature = "psk")]
696             added_psks: psks,
697             pending_reinit: provisional
698                 .applied_proposals
699                 .reinitializations
700                 .first()
701                 .map(|ri| ri.proposal.new_cipher_suite()),
702             active: true,
703             epoch: provisional.group_context.epoch,
704             #[cfg(feature = "custom_proposal")]
705             custom_proposals: provisional.applied_proposals.custom_proposals.clone(),
706             #[cfg(feature = "by_ref_proposal")]
707             unused_proposals: provisional.unused_proposals.clone(),
708         };
709 
710         Ok(update)
711     }
712 
process_commit( &mut self, auth_content: AuthenticatedContent, time_sent: Option<MlsTime>, ) -> Result<CommitMessageDescription, MlsError>713     async fn process_commit(
714         &mut self,
715         auth_content: AuthenticatedContent,
716         time_sent: Option<MlsTime>,
717     ) -> Result<CommitMessageDescription, MlsError> {
718         if self.group_state().pending_reinit.is_some() {
719             return Err(MlsError::GroupUsedAfterReInit);
720         }
721 
722         // Update the new GroupContext's confirmed and interim transcript hashes using the new Commit.
723         let (interim_transcript_hash, confirmed_transcript_hash) = transcript_hashes(
724             self.cipher_suite_provider(),
725             &self.group_state().interim_transcript_hash,
726             &auth_content,
727         )
728         .await?;
729 
730         #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
731         let commit = match auth_content.content.content {
732             Content::Commit(commit) => Ok(commit),
733             _ => Err(MlsError::UnexpectedMessageType),
734         }?;
735 
736         #[cfg(not(any(feature = "private_message", feature = "by_ref_proposal")))]
737         let Content::Commit(commit) = auth_content.content.content;
738 
739         let group_state = self.group_state();
740         let id_provider = self.identity_provider();
741 
742         #[cfg(feature = "by_ref_proposal")]
743         let proposals = group_state
744             .proposals
745             .resolve_for_commit(auth_content.content.sender, commit.proposals)?;
746 
747         #[cfg(not(feature = "by_ref_proposal"))]
748         let proposals = resolve_for_commit(auth_content.content.sender, commit.proposals)?;
749 
750         let mut provisional_state = group_state
751             .apply_resolved(
752                 auth_content.content.sender,
753                 proposals,
754                 commit.path.as_ref().map(|path| &path.leaf_node),
755                 &id_provider,
756                 self.cipher_suite_provider(),
757                 &self.psk_storage(),
758                 &self.mls_rules(),
759                 time_sent,
760                 CommitDirection::Receive,
761             )
762             .await?;
763 
764         let sender = commit_sender(&auth_content.content.sender, &provisional_state)?;
765 
766         #[cfg(feature = "state_update")]
767         let mut state_update = self
768             .make_state_update(&provisional_state, commit.path.as_ref(), sender)
769             .await?;
770 
771         #[cfg(not(feature = "state_update"))]
772         let state_update = StateUpdate {};
773 
774         //Verify that the path value is populated if the proposals vector contains any Update
775         // or Remove proposals, or if it's empty. Otherwise, the path value MAY be omitted.
776         if path_update_required(&provisional_state.applied_proposals) && commit.path.is_none() {
777             return Err(MlsError::CommitMissingPath);
778         }
779 
780         if !self.can_continue_processing(&provisional_state) {
781             #[cfg(feature = "state_update")]
782             {
783                 state_update.active = false;
784             }
785 
786             return Ok(CommitMessageDescription {
787                 is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
788                 authenticated_data: auth_content.content.authenticated_data,
789                 committer: *sender,
790                 state_update,
791             });
792         }
793 
794         let update_path = match commit.path {
795             Some(update_path) => Some(
796                 validate_update_path(
797                     &self.identity_provider(),
798                     self.cipher_suite_provider(),
799                     update_path,
800                     &provisional_state,
801                     sender,
802                     time_sent,
803                 )
804                 .await?,
805             ),
806             None => None,
807         };
808 
809         let new_secrets = match update_path {
810             Some(update_path) => {
811                 self.apply_update_path(sender, &update_path, &mut provisional_state)
812                     .await
813             }
814             None => Ok(None),
815         }?;
816 
817         // Update the transcript hash to get the new context.
818         provisional_state.group_context.confirmed_transcript_hash = confirmed_transcript_hash;
819 
820         // Update the parent hashes in the new context
821         provisional_state
822             .public_tree
823             .update_hashes(&[sender], self.cipher_suite_provider())
824             .await?;
825 
826         // Update the tree hash in the new context
827         provisional_state.group_context.tree_hash = provisional_state
828             .public_tree
829             .tree_hash(self.cipher_suite_provider())
830             .await?;
831 
832         if let Some(reinit) = provisional_state.applied_proposals.reinitializations.pop() {
833             self.group_state_mut().pending_reinit = Some(reinit.proposal);
834 
835             #[cfg(feature = "state_update")]
836             {
837                 state_update.active = false;
838             }
839         }
840 
841         if let Some(confirmation_tag) = &auth_content.auth.confirmation_tag {
842             // Update the key schedule to calculate new private keys
843             self.update_key_schedule(
844                 new_secrets,
845                 interim_transcript_hash,
846                 confirmation_tag,
847                 provisional_state,
848             )
849             .await?;
850 
851             Ok(CommitMessageDescription {
852                 is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
853                 authenticated_data: auth_content.content.authenticated_data,
854                 committer: *sender,
855                 state_update,
856             })
857         } else {
858             Err(MlsError::InvalidConfirmationTag)
859         }
860     }
861 
    /// Shared access to the group state. `check_metadata` and
    /// `validate_welcome` read the current group context from it.
    fn group_state(&self) -> &GroupState;
    /// Exclusive access to the group state. The commit-processing code in this
    /// trait uses it to record a pending re-initialization.
    fn group_state_mut(&mut self) -> &mut GroupState;
    /// Leaf index of this processor, if it has one.
    // NOTE(review): presumably `None` for processors that are not group
    // members — confirm against the implementors.
    #[cfg(feature = "private_message")]
    fn self_index(&self) -> Option<LeafIndex>;
    /// Rules object handed to proposal application while a commit is processed.
    fn mls_rules(&self) -> Self::MlsRules;
    /// Identity provider used for update-path validation, update-path
    /// application and key package validation.
    fn identity_provider(&self) -> Self::IdentityProvider;
    /// Cipher suite provider used for all hashing and validation in this trait.
    fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider;
    /// Pre-shared key storage handed to proposal application while a commit is
    /// processed.
    fn psk_storage(&self) -> Self::PreSharedKeyStorage;
    /// Decides whether commit processing may proceed past proposal application.
    ///
    /// When this returns `false`, commit processing returns its description
    /// early (marking the state update inactive when the `state_update` feature
    /// is enabled) without validating or applying the update path.
    fn can_continue_processing(&self, provisional_state: &ProvisionalState) -> bool;

    /// Oldest epoch for which application messages are still accepted.
    ///
    /// `check_metadata` rejects application messages from an epoch below this
    /// value with `MlsError::InvalidEpoch`; `None` disables that check.
    #[cfg(feature = "private_message")]
    fn min_epoch_available(&self) -> Option<u64>;
874 
check_metadata(&self, message: &MlsMessage) -> Result<(), MlsError>875     fn check_metadata(&self, message: &MlsMessage) -> Result<(), MlsError> {
876         let context = &self.group_state().context;
877 
878         if message.version != context.protocol_version {
879             return Err(MlsError::ProtocolVersionMismatch);
880         }
881 
882         if let Some((group_id, epoch, content_type)) = match &message.payload {
883             MlsMessagePayload::Plain(plaintext) => Some((
884                 &plaintext.content.group_id,
885                 plaintext.content.epoch,
886                 plaintext.content.content_type(),
887             )),
888             #[cfg(feature = "private_message")]
889             MlsMessagePayload::Cipher(ciphertext) => Some((
890                 &ciphertext.group_id,
891                 ciphertext.epoch,
892                 ciphertext.content_type,
893             )),
894             _ => None,
895         } {
896             if group_id != &context.group_id {
897                 return Err(MlsError::GroupIdMismatch);
898             }
899 
900             match content_type {
901                 ContentType::Commit => {
902                     if context.epoch != epoch {
903                         Err(MlsError::InvalidEpoch)
904                     } else {
905                         Ok(())
906                     }
907                 }
908                 #[cfg(feature = "by_ref_proposal")]
909                 ContentType::Proposal => {
910                     if context.epoch != epoch {
911                         Err(MlsError::InvalidEpoch)
912                     } else {
913                         Ok(())
914                     }
915                 }
916                 #[cfg(feature = "private_message")]
917                 ContentType::Application => {
918                     if let Some(min) = self.min_epoch_available() {
919                         if epoch < min {
920                             Err(MlsError::InvalidEpoch)
921                         } else {
922                             Ok(())
923                         }
924                     } else {
925                         Ok(())
926                     }
927                 }
928             }?;
929 
930             // Proposal and commit messages must be sent in the current epoch
931             let check_epoch = content_type == ContentType::Commit;
932 
933             #[cfg(feature = "by_ref_proposal")]
934             let check_epoch = check_epoch || content_type == ContentType::Proposal;
935 
936             if check_epoch && epoch != context.epoch {
937                 return Err(MlsError::InvalidEpoch);
938             }
939 
940             // Unencrypted application messages are not allowed
941             #[cfg(feature = "private_message")]
942             if !matches!(&message.payload, MlsMessagePayload::Cipher(_))
943                 && content_type == ContentType::Application
944             {
945                 return Err(MlsError::UnencryptedApplicationMessage);
946             }
947         }
948 
949         Ok(())
950     }
951 
validate_welcome( &self, welcome: &Welcome, version: ProtocolVersion, ) -> Result<(), MlsError>952     fn validate_welcome(
953         &self,
954         welcome: &Welcome,
955         version: ProtocolVersion,
956     ) -> Result<(), MlsError> {
957         let state = self.group_state();
958 
959         (welcome.cipher_suite == state.context.cipher_suite
960             && version == state.context.protocol_version)
961             .then_some(())
962             .ok_or(MlsError::InvalidWelcomeMessage)
963     }
964 
validate_key_package( &self, key_package: &KeyPackage, version: ProtocolVersion, ) -> Result<(), MlsError>965     async fn validate_key_package(
966         &self,
967         key_package: &KeyPackage,
968         version: ProtocolVersion,
969     ) -> Result<(), MlsError> {
970         let cs = self.cipher_suite_provider();
971         let id = self.identity_provider();
972 
973         validate_key_package(key_package, version, cs, &id).await
974     }
975 
    /// Implementor-provided handling of a private (encrypted) message.
    ///
    /// Takes the message by reference and yields an
    /// `EventOrContent<Self::OutputType>` or an `MlsError`.
    // NOTE(review): the caller and the decryption details live outside this
    // part of the file — confirm the exact contract against the implementors.
    #[cfg(feature = "private_message")]
    async fn process_ciphertext(
        &mut self,
        cipher_text: &PrivateMessage,
    ) -> Result<EventOrContent<Self::OutputType>, MlsError>;

    /// Implementor-provided authentication of a public (plaintext) message.
    ///
    /// Consumes the message and yields an `EventOrContent<Self::OutputType>` or
    /// an `MlsError`.
    // NOTE(review): presumably verifies the sender's signature / membership
    // proof before the content is handed on — confirm against the implementors.
    async fn verify_plaintext_authentication(
        &self,
        message: PublicMessage,
    ) -> Result<EventOrContent<Self::OutputType>, MlsError>;
986 
apply_update_path( &mut self, sender: LeafIndex, update_path: &ValidatedUpdatePath, provisional_state: &mut ProvisionalState, ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError>987     async fn apply_update_path(
988         &mut self,
989         sender: LeafIndex,
990         update_path: &ValidatedUpdatePath,
991         provisional_state: &mut ProvisionalState,
992     ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError> {
993         provisional_state
994             .public_tree
995             .apply_update_path(
996                 sender,
997                 update_path,
998                 &provisional_state.group_context.extensions,
999                 self.identity_provider(),
1000                 self.cipher_suite_provider(),
1001             )
1002             .await
1003             .map(|_| None)
1004     }
1005 
    /// Implementor-provided final step of commit processing.
    ///
    /// The commit-processing code in this trait calls it only when the commit
    /// carries a confirmation tag, passing:
    ///
    /// * `secrets` - the value returned by `apply_update_path`, or `None` when
    ///   the commit had no update path;
    /// * `interim_transcript_hash` - the interim transcript hash for the commit;
    /// * `confirmation_tag` - the tag taken from the commit's authentication
    ///   data;
    /// * `provisional_public_state` - the provisional state, by value, after
    ///   its confirmed transcript hash and tree hash have been updated.
    // NOTE(review): presumably the implementor verifies the confirmation tag
    // and installs the new epoch state — confirm against the implementors.
    async fn update_key_schedule(
        &mut self,
        secrets: Option<(TreeKemPrivate, PathSecret)>,
        interim_transcript_hash: InterimTranscriptHash,
        confirmation_tag: &ConfirmationTag,
        provisional_public_state: ProvisionalState,
    ) -> Result<(), MlsError>;
1013 }
1014 
1015 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
validate_key_package<C: CipherSuiteProvider, I: IdentityProvider>( key_package: &KeyPackage, version: ProtocolVersion, cs: &C, id: &I, ) -> Result<(), MlsError>1016 pub(crate) async fn validate_key_package<C: CipherSuiteProvider, I: IdentityProvider>(
1017     key_package: &KeyPackage,
1018     version: ProtocolVersion,
1019     cs: &C,
1020     id: &I,
1021 ) -> Result<(), MlsError> {
1022     let validator = LeafNodeValidator::new(cs, id, None);
1023 
1024     #[cfg(feature = "std")]
1025     let context = Some(MlsTime::now());
1026 
1027     #[cfg(not(feature = "std"))]
1028     let context = None;
1029 
1030     let context = ValidationContext::Add(context);
1031 
1032     validator
1033         .check_if_valid(&key_package.leaf_node, context)
1034         .await?;
1035 
1036     validate_key_package_properties(key_package, version, cs).await?;
1037 
1038     Ok(())
1039 }
1040