1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec;
6 use alloc::vec::Vec;
7 use core::fmt::{self, Debug};
8 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
9 use mls_rs_core::{
10     crypto::{CipherSuiteProvider, SignatureSecretKey},
11     error::IntoAnyError,
12 };
13 
14 use crate::{
15     cipher_suite::CipherSuite,
16     client::MlsError,
17     client_config::ClientConfig,
18     extension::RatchetTreeExt,
19     identity::SigningIdentity,
20     protocol_version::ProtocolVersion,
21     signer::Signable,
22     tree_kem::{
23         kem::TreeKem, node::LeafIndex, path_secret::PathSecret, TreeKemPrivate, UpdatePath,
24     },
25     ExtensionList, MlsRules,
26 };
27 
28 #[cfg(all(not(mls_build_async), feature = "rayon"))]
29 use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
30 
31 use crate::tree_kem::leaf_node::LeafNode;
32 
33 #[cfg(not(feature = "private_message"))]
34 use crate::WireFormat;
35 
36 #[cfg(feature = "psk")]
37 use crate::{
38     group::{JustPreSharedKeyID, PskGroupId, ResumptionPSKUsage, ResumptionPsk},
39     psk::ExternalPskId,
40 };
41 
42 use super::{
43     confirmation_tag::ConfirmationTag,
44     framing::{Content, MlsMessage, MlsMessagePayload, Sender},
45     key_schedule::{KeySchedule, WelcomeSecret},
46     message_processor::{path_update_required, MessageProcessor},
47     message_signature::AuthenticatedContent,
48     mls_rules::CommitDirection,
49     proposal::{Proposal, ProposalOrRef},
50     ConfirmedTranscriptHash, EncryptedGroupSecrets, ExportedTree, Group, GroupContext, GroupInfo,
51     Welcome,
52 };
53 
54 #[cfg(not(feature = "by_ref_proposal"))]
55 use super::proposal_cache::prepare_commit;
56 
57 #[cfg(feature = "custom_proposal")]
58 use super::proposal::CustomProposal;
59 
/// Content of an MLS commit: the proposals it covers and an optional update path.
///
/// NOTE(review): field order presumably defines the derived `MlsEncode` /
/// `MlsDecode` layout — do not reorder fields without checking the wire format.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(mls_rs_core::arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct Commit {
    /// Proposals covered by this commit, each either carried by value or
    /// referenced (see `ProposalOrRef`).
    pub proposals: Vec<ProposalOrRef>,
    /// Update path; `None` when the committer did not perform a path update.
    pub path: Option<UpdatePath>,
}
67 
/// State kept for a commit this member created but has not yet applied.
///
/// Produced by `Group::commit_internal` and stored in `Group::pending_commit`.
#[derive(Clone, PartialEq, Debug, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(super) struct CommitGeneration {
    /// Signed commit content, including its confirmation tag.
    pub content: AuthenticatedContent,
    /// Private tree state computed for the epoch the commit would create.
    pub pending_private_tree: TreeKemPrivate,
    /// Commit secret from the update path, or an empty secret if no path
    /// update was performed.
    pub pending_commit_secret: PathSecret,
    /// Hash of the wire-formatted commit message (see `CommitHash::compute`).
    pub commit_message_hash: CommitHash,
}
76 
/// Cipher-suite hash of an MLS-encoded commit message.
///
/// Built by `CommitHash::compute`; its `Debug` output prints the raw bytes
/// in a readable form instead of as a plain integer list.
#[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct CommitHash(
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    Vec<u8>,
);
84 
85 impl Debug for CommitHash {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result86     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87         mls_rs_core::debug::pretty_bytes(&self.0)
88             .named("CommitHash")
89             .fmt(f)
90     }
91 }
92 
93 impl CommitHash {
94     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
compute<CS: CipherSuiteProvider>( cs: &CS, commit: &MlsMessage, ) -> Result<Self, MlsError>95     pub(crate) async fn compute<CS: CipherSuiteProvider>(
96         cs: &CS,
97         commit: &MlsMessage,
98     ) -> Result<Self, MlsError> {
99         cs.hash(&commit.mls_encode_to_vec()?)
100             .await
101             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
102             .map(Self)
103     }
104 }
105 
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, Debug)]
#[non_exhaustive]
/// Result of MLS commit operation using
/// [`Group::commit`](crate::group::Group::commit) or
/// [`CommitBuilder::build`](CommitBuilder::build).
///
/// Producing this value does not apply the commit: the group stores it as its
/// pending commit, and it is applied later.
pub struct CommitOutput {
    /// Commit message to send to other group members.
    pub commit_message: MlsMessage,
    /// Welcome messages to send to new group members. If the commit does not add members,
    /// this list is empty. Otherwise, if [`MlsRules::commit_options`] returns `single_welcome_message`
    /// set to true, then this list contains a single message sent to all members. Else, the list
    /// contains one message for each added member. Recipients of each message can be identified using
    /// [`MlsMessage::key_package_reference`] of their key packages and
    /// [`MlsMessage::welcome_key_package_references`].
    pub welcome_messages: Vec<MlsMessage>,
    /// Ratchet tree that can be sent out of band if
    /// `ratchet_tree_extension` is not used according to
    /// [`MlsRules::commit_options`].
    ///
    /// This is `Some` exactly when `ratchet_tree_extension` is false; otherwise
    /// the tree is carried inside the group info extensions instead.
    pub ratchet_tree: Option<ExportedTree<'static>>,
    /// A group info that can be provided to new members in order to enable external commit
    /// functionality. This value is set if [`MlsRules::commit_options`] returns
    /// `allow_external_commit` set to true, and is `None` otherwise.
    pub external_commit_group_info: Option<MlsMessage>,
    /// Proposals that were received in the prior epoch but not included in the following commit.
    #[cfg(feature = "by_ref_proposal")]
    pub unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
}
137 
138 #[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
139 impl CommitOutput {
140     /// Commit message to send to other group members.
141     #[cfg(feature = "ffi")]
commit_message(&self) -> &MlsMessage142     pub fn commit_message(&self) -> &MlsMessage {
143         &self.commit_message
144     }
145 
146     /// Welcome message to send to new group members.
147     #[cfg(feature = "ffi")]
welcome_messages(&self) -> &[MlsMessage]148     pub fn welcome_messages(&self) -> &[MlsMessage] {
149         &self.welcome_messages
150     }
151 
152     /// Ratchet tree that can be sent out of band if
153     /// `ratchet_tree_extension` is not used according to
154     /// [`MlsRules::commit_options`].
155     #[cfg(feature = "ffi")]
ratchet_tree(&self) -> Option<&ExportedTree<'static>>156     pub fn ratchet_tree(&self) -> Option<&ExportedTree<'static>> {
157         self.ratchet_tree.as_ref()
158     }
159 
160     /// A group info that can be provided to new members in order to enable external commit
161     /// functionality. This value is set if [`MlsRules::commit_options`] returns
162     /// `allow_external_commit` set to true.
163     #[cfg(feature = "ffi")]
external_commit_group_info(&self) -> Option<&MlsMessage>164     pub fn external_commit_group_info(&self) -> Option<&MlsMessage> {
165         self.external_commit_group_info.as_ref()
166     }
167 
168     /// Proposals that were received in the prior epoch but not included in the following commit.
169     #[cfg(all(feature = "ffi", feature = "by_ref_proposal"))]
unused_proposals(&self) -> &[crate::mls_rules::ProposalInfo<Proposal>]170     pub fn unused_proposals(&self) -> &[crate::mls_rules::ProposalInfo<Proposal>] {
171         &self.unused_proposals
172     }
173 }
174 
175 /// Build a commit with multiple proposals by-value.
176 ///
177 /// Proposals within a commit can be by-value or by-reference.
178 /// Proposals received during the current epoch will be added to the resulting
179 /// commit by-reference automatically so long as they pass the rules defined
180 /// in the current
181 /// [proposal rules](crate::client_builder::ClientBuilder::mls_rules).
182 pub struct CommitBuilder<'a, C>
183 where
184     C: ClientConfig + Clone,
185 {
186     group: &'a mut Group<C>,
187     pub(super) proposals: Vec<Proposal>,
188     authenticated_data: Vec<u8>,
189     group_info_extensions: ExtensionList,
190     new_signer: Option<SignatureSecretKey>,
191     new_signing_identity: Option<SigningIdentity>,
192 }
193 
194 impl<'a, C> CommitBuilder<'a, C>
195 where
196     C: ClientConfig + Clone,
197 {
198     /// Insert an [`AddProposal`](crate::group::proposal::AddProposal) into
199     /// the current commit that is being built.
add_member(mut self, key_package: MlsMessage) -> Result<CommitBuilder<'a, C>, MlsError>200     pub fn add_member(mut self, key_package: MlsMessage) -> Result<CommitBuilder<'a, C>, MlsError> {
201         let proposal = self.group.add_proposal(key_package)?;
202         self.proposals.push(proposal);
203         Ok(self)
204     }
205 
206     /// Set group info extensions that will be inserted into the resulting
207     /// [welcome messages](CommitOutput::welcome_messages) for new members.
208     ///
209     /// Group info extensions that are transmitted as part of a welcome message
210     /// are encrypted along with other private values.
211     ///
212     /// These extensions can be retrieved as part of
213     /// [`NewMemberInfo`](crate::group::NewMemberInfo) that is returned
214     /// by joining the group via
215     /// [`Client::join_group`](crate::Client::join_group).
set_group_info_ext(self, extensions: ExtensionList) -> Self216     pub fn set_group_info_ext(self, extensions: ExtensionList) -> Self {
217         Self {
218             group_info_extensions: extensions,
219             ..self
220         }
221     }
222 
223     /// Insert a [`RemoveProposal`](crate::group::proposal::RemoveProposal) into
224     /// the current commit that is being built.
remove_member(mut self, index: u32) -> Result<Self, MlsError>225     pub fn remove_member(mut self, index: u32) -> Result<Self, MlsError> {
226         let proposal = self.group.remove_proposal(index)?;
227         self.proposals.push(proposal);
228         Ok(self)
229     }
230 
231     /// Insert a
232     /// [`GroupContextExtensions`](crate::group::proposal::Proposal::GroupContextExtensions)
233     /// into the current commit that is being built.
set_group_context_ext(mut self, extensions: ExtensionList) -> Result<Self, MlsError>234     pub fn set_group_context_ext(mut self, extensions: ExtensionList) -> Result<Self, MlsError> {
235         let proposal = self.group.group_context_extensions_proposal(extensions);
236         self.proposals.push(proposal);
237         Ok(self)
238     }
239 
240     /// Insert a
241     /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) with
242     /// an external PSK into the current commit that is being built.
243     #[cfg(feature = "psk")]
add_external_psk(mut self, psk_id: ExternalPskId) -> Result<Self, MlsError>244     pub fn add_external_psk(mut self, psk_id: ExternalPskId) -> Result<Self, MlsError> {
245         let key_id = JustPreSharedKeyID::External(psk_id);
246         let proposal = self.group.psk_proposal(key_id)?;
247         self.proposals.push(proposal);
248         Ok(self)
249     }
250 
251     /// Insert a
252     /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) with
253     /// a resumption PSK into the current commit that is being built.
254     #[cfg(feature = "psk")]
add_resumption_psk(mut self, psk_epoch: u64) -> Result<Self, MlsError>255     pub fn add_resumption_psk(mut self, psk_epoch: u64) -> Result<Self, MlsError> {
256         let psk_id = ResumptionPsk {
257             psk_epoch,
258             usage: ResumptionPSKUsage::Application,
259             psk_group_id: PskGroupId(self.group.group_id().to_vec()),
260         };
261 
262         let key_id = JustPreSharedKeyID::Resumption(psk_id);
263         let proposal = self.group.psk_proposal(key_id)?;
264         self.proposals.push(proposal);
265         Ok(self)
266     }
267 
268     /// Insert a [`ReInitProposal`](crate::group::proposal::ReInitProposal) into
269     /// the current commit that is being built.
reinit( mut self, group_id: Option<Vec<u8>>, version: ProtocolVersion, cipher_suite: CipherSuite, extensions: ExtensionList, ) -> Result<Self, MlsError>270     pub fn reinit(
271         mut self,
272         group_id: Option<Vec<u8>>,
273         version: ProtocolVersion,
274         cipher_suite: CipherSuite,
275         extensions: ExtensionList,
276     ) -> Result<Self, MlsError> {
277         let proposal = self
278             .group
279             .reinit_proposal(group_id, version, cipher_suite, extensions)?;
280 
281         self.proposals.push(proposal);
282         Ok(self)
283     }
284 
285     /// Insert a [`CustomProposal`](crate::group::proposal::CustomProposal) into
286     /// the current commit that is being built.
287     #[cfg(feature = "custom_proposal")]
custom_proposal(mut self, proposal: CustomProposal) -> Self288     pub fn custom_proposal(mut self, proposal: CustomProposal) -> Self {
289         self.proposals.push(Proposal::Custom(proposal));
290         self
291     }
292 
293     /// Insert a proposal that was previously constructed such as when a
294     /// proposal is returned from
295     /// [`StateUpdate::unused_proposals`](super::StateUpdate::unused_proposals).
raw_proposal(mut self, proposal: Proposal) -> Self296     pub fn raw_proposal(mut self, proposal: Proposal) -> Self {
297         self.proposals.push(proposal);
298         self
299     }
300 
301     /// Insert proposals that were previously constructed such as when a
302     /// proposal is returned from
303     /// [`StateUpdate::unused_proposals`](super::StateUpdate::unused_proposals).
raw_proposals(mut self, mut proposals: Vec<Proposal>) -> Self304     pub fn raw_proposals(mut self, mut proposals: Vec<Proposal>) -> Self {
305         self.proposals.append(&mut proposals);
306         self
307     }
308 
309     /// Add additional authenticated data to the commit.
310     ///
311     /// # Warning
312     ///
313     /// The data provided here is always sent unencrypted.
authenticated_data(self, authenticated_data: Vec<u8>) -> Self314     pub fn authenticated_data(self, authenticated_data: Vec<u8>) -> Self {
315         Self {
316             authenticated_data,
317             ..self
318         }
319     }
320 
321     /// Change the committer's signing identity as part of making this commit.
322     /// This will only succeed if the [`IdentityProvider`](crate::IdentityProvider)
323     /// in use by the group considers the credential inside this signing_identity
324     /// [valid](crate::IdentityProvider::validate_member)
325     /// and results in the same
326     /// [identity](crate::IdentityProvider::identity)
327     /// being used.
set_new_signing_identity( self, signer: SignatureSecretKey, signing_identity: SigningIdentity, ) -> Self328     pub fn set_new_signing_identity(
329         self,
330         signer: SignatureSecretKey,
331         signing_identity: SigningIdentity,
332     ) -> Self {
333         Self {
334             new_signer: Some(signer),
335             new_signing_identity: Some(signing_identity),
336             ..self
337         }
338     }
339 
340     /// Finalize the commit to send.
341     ///
342     /// # Errors
343     ///
344     /// This function will return an error if any of the proposals provided
345     /// are not contextually valid according to the rules defined by the
346     /// MLS RFC, or if they do not pass the custom rules defined by the current
347     /// [proposal rules](crate::client_builder::ClientBuilder::mls_rules).
348     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
build(self) -> Result<CommitOutput, MlsError>349     pub async fn build(self) -> Result<CommitOutput, MlsError> {
350         self.group
351             .commit_internal(
352                 self.proposals,
353                 None,
354                 self.authenticated_data,
355                 self.group_info_extensions,
356                 self.new_signer,
357                 self.new_signing_identity,
358             )
359             .await
360     }
361 }
362 
363 impl<C> Group<C>
364 where
365     C: ClientConfig + Clone,
366 {
367     /// Perform a commit of received proposals.
368     ///
369     /// This function is the equivalent of [`Group::commit_builder`] immediately
370     /// followed by [`CommitBuilder::build`]. Any received proposals since the
371     /// last commit will be included in the resulting message by-reference.
372     ///
373     /// Data provided in the `authenticated_data` field will be placed into
374     /// the resulting commit message unencrypted.
375     ///
376     /// # Pending Commits
377     ///
378     /// When a commit is created, it is not applied immediately in order to
379     /// allow for the resolution of conflicts when multiple members of a group
380     /// attempt to make commits at the same time. For example, a central relay
381     /// can be used to decide which commit should be accepted by the group by
382     /// determining a consistent view of commit packet order for all clients.
383     ///
384     /// Pending commits are stored internally as part of the group's state
385     /// so they do not need to be tracked outside of this library. Any commit
386     /// message that is processed before calling [Group::apply_pending_commit]
387     /// will clear the currently pending commit.
388     ///
389     /// # Empty Commits
390     ///
391     /// Sending a commit that contains no proposals is a valid operation
392     /// within the MLS protocol. It is useful for providing stronger forward
393     /// secrecy and post-compromise security, especially for long running
394     /// groups when group membership does not change often.
395     ///
396     /// # Path Updates
397     ///
398     /// Path updates provide forward secrecy and post-compromise security
399     /// within the MLS protocol.
400     /// The `path_required` option returned by [`MlsRules::commit_options`](`crate::MlsRules::commit_options`)
401     /// controls the ability of a group to send a commit without a path update.
402     /// An update path will automatically be sent if there are no proposals
403     /// in the commit, or if any proposal other than
404     /// [`Add`](crate::group::proposal::Proposal::Add),
405     /// [`Psk`](crate::group::proposal::Proposal::Psk),
406     /// or [`ReInit`](crate::group::proposal::Proposal::ReInit) are part of the commit.
407     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
commit(&mut self, authenticated_data: Vec<u8>) -> Result<CommitOutput, MlsError>408     pub async fn commit(&mut self, authenticated_data: Vec<u8>) -> Result<CommitOutput, MlsError> {
409         self.commit_internal(
410             vec![],
411             None,
412             authenticated_data,
413             Default::default(),
414             None,
415             None,
416         )
417         .await
418     }
419 
420     /// Create a new commit builder that can include proposals
421     /// by-value.
commit_builder(&mut self) -> CommitBuilder<C>422     pub fn commit_builder(&mut self) -> CommitBuilder<C> {
423         CommitBuilder {
424             group: self,
425             proposals: Default::default(),
426             authenticated_data: Default::default(),
427             group_info_extensions: Default::default(),
428             new_signer: Default::default(),
429             new_signing_identity: Default::default(),
430         }
431     }
432 
433     /// Returns commit and optional [`MlsMessage`] containing a welcome message
434     /// for newly added members.
435     #[allow(clippy::too_many_arguments)]
436     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
commit_internal( &mut self, proposals: Vec<Proposal>, external_leaf: Option<&LeafNode>, authenticated_data: Vec<u8>, mut welcome_group_info_extensions: ExtensionList, new_signer: Option<SignatureSecretKey>, new_signing_identity: Option<SigningIdentity>, ) -> Result<CommitOutput, MlsError>437     pub(super) async fn commit_internal(
438         &mut self,
439         proposals: Vec<Proposal>,
440         external_leaf: Option<&LeafNode>,
441         authenticated_data: Vec<u8>,
442         mut welcome_group_info_extensions: ExtensionList,
443         new_signer: Option<SignatureSecretKey>,
444         new_signing_identity: Option<SigningIdentity>,
445     ) -> Result<CommitOutput, MlsError> {
446         if self.pending_commit.is_some() {
447             return Err(MlsError::ExistingPendingCommit);
448         }
449 
450         if self.state.pending_reinit.is_some() {
451             return Err(MlsError::GroupUsedAfterReInit);
452         }
453 
454         let mls_rules = self.config.mls_rules();
455 
456         let is_external = external_leaf.is_some();
457 
458         // Construct an initial Commit object with the proposals field populated from Proposals
459         // received during the current epoch, and an empty path field. Add passed in proposals
460         // by value
461         let sender = if is_external {
462             Sender::NewMemberCommit
463         } else {
464             Sender::Member(*self.private_tree.self_index)
465         };
466 
467         let new_signer_ref = new_signer.as_ref().unwrap_or(&self.signer);
468         let old_signer = &self.signer;
469 
470         #[cfg(feature = "std")]
471         let time = Some(crate::time::MlsTime::now());
472 
473         #[cfg(not(feature = "std"))]
474         let time = None;
475 
476         #[cfg(feature = "by_ref_proposal")]
477         let proposals = self.state.proposals.prepare_commit(sender, proposals);
478 
479         #[cfg(not(feature = "by_ref_proposal"))]
480         let proposals = prepare_commit(sender, proposals);
481 
482         let mut provisional_state = self
483             .state
484             .apply_resolved(
485                 sender,
486                 proposals,
487                 external_leaf,
488                 &self.config.identity_provider(),
489                 &self.cipher_suite_provider,
490                 &self.config.secret_store(),
491                 &mls_rules,
492                 time,
493                 CommitDirection::Send,
494             )
495             .await?;
496 
497         let (mut provisional_private_tree, _) =
498             self.provisional_private_tree(&provisional_state)?;
499 
500         if is_external {
501             provisional_private_tree.self_index = provisional_state
502                 .external_init_index
503                 .ok_or(MlsError::ExternalCommitMissingExternalInit)?;
504 
505             self.private_tree.self_index = provisional_private_tree.self_index;
506         }
507 
508         let mut provisional_group_context = provisional_state.group_context;
509 
510         // Decide whether to populate the path field: If the path field is required based on the
511         // proposals that are in the commit (see above), then it MUST be populated. Otherwise, the
512         // sender MAY omit the path field at its discretion.
513         let commit_options = mls_rules
514             .commit_options(
515                 &provisional_state.public_tree.roster(),
516                 &provisional_group_context.extensions,
517                 &provisional_state.applied_proposals,
518             )
519             .map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?;
520 
521         let perform_path_update = commit_options.path_required
522             || path_update_required(&provisional_state.applied_proposals);
523 
524         let (update_path, path_secrets, commit_secret) = if perform_path_update {
525             // If populating the path field: Create an UpdatePath using the new tree. Any new
526             // member (from an add proposal) MUST be excluded from the resolution during the
527             // computation of the UpdatePath. The GroupContext for this operation uses the
528             // group_id, epoch, tree_hash, and confirmed_transcript_hash values in the initial
529             // GroupContext object. The leaf_key_package for this UpdatePath must have a
530             // parent_hash extension.
531             let encap_gen = TreeKem::new(
532                 &mut provisional_state.public_tree,
533                 &mut provisional_private_tree,
534             )
535             .encap(
536                 &mut provisional_group_context,
537                 &provisional_state.indexes_of_added_kpkgs,
538                 new_signer_ref,
539                 self.config.leaf_properties(),
540                 new_signing_identity,
541                 &self.cipher_suite_provider,
542                 #[cfg(test)]
543                 &self.commit_modifiers,
544             )
545             .await?;
546 
547             (
548                 Some(encap_gen.update_path),
549                 Some(encap_gen.path_secrets),
550                 encap_gen.commit_secret,
551             )
552         } else {
553             // Update the tree hash, since it was not updated by encap.
554             provisional_state
555                 .public_tree
556                 .update_hashes(
557                     &[provisional_private_tree.self_index],
558                     &self.cipher_suite_provider,
559                 )
560                 .await?;
561 
562             provisional_group_context.tree_hash = provisional_state
563                 .public_tree
564                 .tree_hash(&self.cipher_suite_provider)
565                 .await?;
566 
567             (None, None, PathSecret::empty(&self.cipher_suite_provider))
568         };
569 
570         #[cfg(feature = "psk")]
571         let (psk_secret, psks) = self
572             .get_psk(&provisional_state.applied_proposals.psks)
573             .await?;
574 
575         #[cfg(not(feature = "psk"))]
576         let psk_secret = self.get_psk();
577 
578         let added_key_pkgs: Vec<_> = provisional_state
579             .applied_proposals
580             .additions
581             .iter()
582             .map(|info| info.proposal.key_package.clone())
583             .collect();
584 
585         let commit = Commit {
586             proposals: provisional_state.applied_proposals.into_proposals_or_refs(),
587             path: update_path,
588         };
589 
590         let mut auth_content = AuthenticatedContent::new_signed(
591             &self.cipher_suite_provider,
592             self.context(),
593             sender,
594             Content::Commit(alloc::boxed::Box::new(commit)),
595             old_signer,
596             #[cfg(feature = "private_message")]
597             self.encryption_options()?.control_wire_format(sender),
598             #[cfg(not(feature = "private_message"))]
599             WireFormat::PublicMessage,
600             authenticated_data,
601         )
602         .await?;
603 
604         // Use the signature, the commit_secret and the psk_secret to advance the key schedule and
605         // compute the confirmation_tag value in the MlsPlaintext.
606         let confirmed_transcript_hash = ConfirmedTranscriptHash::create(
607             self.cipher_suite_provider(),
608             &self.state.interim_transcript_hash,
609             &auth_content,
610         )
611         .await?;
612 
613         provisional_group_context.confirmed_transcript_hash = confirmed_transcript_hash;
614 
615         let key_schedule_result = KeySchedule::from_key_schedule(
616             &self.key_schedule,
617             &commit_secret,
618             &provisional_group_context,
619             #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
620             self.state.public_tree.total_leaf_count(),
621             &psk_secret,
622             &self.cipher_suite_provider,
623         )
624         .await?;
625 
626         let confirmation_tag = ConfirmationTag::create(
627             &key_schedule_result.confirmation_key,
628             &provisional_group_context.confirmed_transcript_hash,
629             &self.cipher_suite_provider,
630         )
631         .await?;
632 
633         auth_content.auth.confirmation_tag = Some(confirmation_tag.clone());
634 
635         let ratchet_tree_ext = commit_options
636             .ratchet_tree_extension
637             .then(|| RatchetTreeExt {
638                 tree_data: ExportedTree::new(provisional_state.public_tree.nodes.clone()),
639             });
640 
641         // Generate external commit group info if required by commit_options
642         let external_commit_group_info = match commit_options.allow_external_commit {
643             true => {
644                 let mut extensions = ExtensionList::new();
645 
646                 extensions.set_from({
647                     key_schedule_result
648                         .key_schedule
649                         .get_external_key_pair_ext(&self.cipher_suite_provider)
650                         .await?
651                 })?;
652 
653                 if let Some(ref ratchet_tree_ext) = ratchet_tree_ext {
654                     extensions.set_from(ratchet_tree_ext.clone())?;
655                 }
656 
657                 let info = self
658                     .make_group_info(
659                         &provisional_group_context,
660                         extensions,
661                         &confirmation_tag,
662                         new_signer_ref,
663                     )
664                     .await?;
665 
666                 let msg =
667                     MlsMessage::new(self.protocol_version(), MlsMessagePayload::GroupInfo(info));
668 
669                 Some(msg)
670             }
671             false => None,
672         };
673 
674         // Build the group info that will be placed into the welcome messages.
675         // Add the ratchet tree extension if necessary
676         if let Some(ratchet_tree_ext) = ratchet_tree_ext {
677             welcome_group_info_extensions.set_from(ratchet_tree_ext)?;
678         }
679 
680         let welcome_group_info = self
681             .make_group_info(
682                 &provisional_group_context,
683                 welcome_group_info_extensions,
684                 &confirmation_tag,
685                 new_signer_ref,
686             )
687             .await?;
688 
689         // Encrypt the GroupInfo using the key and nonce derived from the joiner_secret for
690         // the new epoch
691         let welcome_secret = WelcomeSecret::from_joiner_secret(
692             &self.cipher_suite_provider,
693             &key_schedule_result.joiner_secret,
694             &psk_secret,
695         )
696         .await?;
697 
698         let encrypted_group_info = welcome_secret
699             .encrypt(&welcome_group_info.mls_encode_to_vec()?)
700             .await?;
701 
702         // Encrypt path secrets and joiner secret to new members
703         let path_secrets = path_secrets.as_ref();
704 
705         #[cfg(not(any(mls_build_async, not(feature = "rayon"))))]
706         let encrypted_path_secrets: Vec<_> = added_key_pkgs
707             .into_par_iter()
708             .zip(provisional_state.indexes_of_added_kpkgs)
709             .map(|(key_package, leaf_index)| {
710                 self.encrypt_group_secrets(
711                     &key_package,
712                     leaf_index,
713                     &key_schedule_result.joiner_secret,
714                     path_secrets,
715                     #[cfg(feature = "psk")]
716                     psks.clone(),
717                     &encrypted_group_info,
718                 )
719             })
720             .try_collect()?;
721 
722         #[cfg(any(mls_build_async, not(feature = "rayon")))]
723         let encrypted_path_secrets = {
724             let mut secrets = Vec::new();
725 
726             for (key_package, leaf_index) in added_key_pkgs
727                 .into_iter()
728                 .zip(provisional_state.indexes_of_added_kpkgs)
729             {
730                 secrets.push(
731                     self.encrypt_group_secrets(
732                         &key_package,
733                         leaf_index,
734                         &key_schedule_result.joiner_secret,
735                         path_secrets,
736                         #[cfg(feature = "psk")]
737                         psks.clone(),
738                         &encrypted_group_info,
739                     )
740                     .await?,
741                 );
742             }
743 
744             secrets
745         };
746 
747         let welcome_messages =
748             if commit_options.single_welcome_message && !encrypted_path_secrets.is_empty() {
749                 vec![self.make_welcome_message(encrypted_path_secrets, encrypted_group_info)]
750             } else {
751                 encrypted_path_secrets
752                     .into_iter()
753                     .map(|s| self.make_welcome_message(vec![s], encrypted_group_info.clone()))
754                     .collect()
755             };
756 
757         let commit_message = self.format_for_wire(auth_content.clone()).await?;
758 
759         let pending_commit = CommitGeneration {
760             content: auth_content,
761             pending_private_tree: provisional_private_tree,
762             pending_commit_secret: commit_secret,
763             commit_message_hash: CommitHash::compute(&self.cipher_suite_provider, &commit_message)
764                 .await?,
765         };
766 
767         self.pending_commit = Some(pending_commit);
768 
769         let ratchet_tree = (!commit_options.ratchet_tree_extension)
770             .then(|| ExportedTree::new(provisional_state.public_tree.nodes));
771 
772         if let Some(signer) = new_signer {
773             self.signer = signer;
774         }
775 
776         Ok(CommitOutput {
777             commit_message,
778             welcome_messages,
779             ratchet_tree,
780             external_commit_group_info,
781             #[cfg(feature = "by_ref_proposal")]
782             unused_proposals: provisional_state.unused_proposals,
783         })
784     }
785 
786     // Construct a GroupInfo reflecting the new state
787     // Group ID, epoch, tree, and confirmed transcript hash from the new state
788     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
make_group_info( &self, group_context: &GroupContext, extensions: ExtensionList, confirmation_tag: &ConfirmationTag, signer: &SignatureSecretKey, ) -> Result<GroupInfo, MlsError>789     async fn make_group_info(
790         &self,
791         group_context: &GroupContext,
792         extensions: ExtensionList,
793         confirmation_tag: &ConfirmationTag,
794         signer: &SignatureSecretKey,
795     ) -> Result<GroupInfo, MlsError> {
796         let mut group_info = GroupInfo {
797             group_context: group_context.clone(),
798             extensions,
799             confirmation_tag: confirmation_tag.clone(), // The confirmation_tag from the MlsPlaintext object
800             signer: LeafIndex(self.current_member_index()),
801             signature: vec![],
802         };
803 
804         group_info.grease(self.cipher_suite_provider())?;
805 
806         // Sign the GroupInfo using the member's private signing key
807         group_info
808             .sign(&self.cipher_suite_provider, signer, &())
809             .await?;
810 
811         Ok(group_info)
812     }
813 
make_welcome_message( &self, secrets: Vec<EncryptedGroupSecrets>, encrypted_group_info: Vec<u8>, ) -> MlsMessage814     fn make_welcome_message(
815         &self,
816         secrets: Vec<EncryptedGroupSecrets>,
817         encrypted_group_info: Vec<u8>,
818     ) -> MlsMessage {
819         MlsMessage::new(
820             self.context().protocol_version,
821             MlsMessagePayload::Welcome(Welcome {
822                 cipher_suite: self.context().cipher_suite,
823                 secrets,
824                 encrypted_group_info,
825             }),
826         )
827     }
828 }
829 
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::vec::Vec;

    use crate::{
        crypto::SignatureSecretKey,
        tree_kem::{leaf_node::LeafNode, TreeKemPublic, UpdatePathNode},
    };

    /// Hooks that let tests tamper with pieces of a commit while it is being built.
    ///
    /// Every field is a plain function pointer, which keeps the struct `Copy`.
    #[derive(Copy, Clone, Debug)]
    pub struct CommitModifiers {
        /// Receives a mutable leaf node and a signing key; returns an optional signing key.
        pub modify_leaf: fn(&mut LeafNode, &SignatureSecretKey) -> Option<SignatureSecretKey>,
        /// Receives the public tree mutably.
        pub modify_tree: fn(&mut TreeKemPublic),
        /// Maps a list of update-path nodes to a replacement list.
        pub modify_path: fn(Vec<UpdatePathNode>) -> Vec<UpdatePathNode>,
    }

    impl Default for CommitModifiers {
        /// No-op hooks: the leaf and tree are untouched, no key is returned, and the
        /// path is passed through unchanged.
        fn default() -> Self {
            fn keep_leaf(_: &mut LeafNode, _: &SignatureSecretKey) -> Option<SignatureSecretKey> {
                None
            }

            fn keep_tree(_: &mut TreeKemPublic) {}

            fn keep_path(path: Vec<UpdatePathNode>) -> Vec<UpdatePathNode> {
                path
            }

            Self {
                modify_leaf: keep_leaf,
                modify_tree: keep_tree,
                modify_path: keep_path,
            }
        }
    }
}
856 
857 #[cfg(test)]
858 mod tests {
859     use alloc::boxed::Box;
860 
861     use mls_rs_core::{
862         error::IntoAnyError,
863         extension::ExtensionType,
864         identity::{CredentialType, IdentityProvider},
865         time::MlsTime,
866     };
867 
868     use crate::{
869         crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
870         group::{mls_rules::DefaultMlsRules, test_utils::test_group_custom},
871         mls_rules::CommitOptions,
872         Client,
873     };
874 
875     #[cfg(feature = "by_ref_proposal")]
876     use crate::extension::ExternalSendersExt;
877 
878     use crate::{
879         client::test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
880         client_builder::{
881             test_utils::TestClientConfig, BaseConfig, ClientBuilder, WithCryptoProvider,
882             WithIdentityProvider,
883         },
884         client_config::ClientConfig,
885         extension::test_utils::{TestExtension, TEST_EXTENSION_TYPE},
886         group::{
887             proposal::ProposalType,
888             test_utils::{test_group_custom_config, test_n_member_group},
889         },
890         identity::test_utils::get_test_signing_identity,
891         identity::{basic::BasicIdentityProvider, test_utils::get_test_basic_credential},
892         key_package::test_utils::test_key_package_message,
893     };
894 
895     use crate::extension::RequiredCapabilitiesExt;
896 
897     #[cfg(feature = "psk")]
898     use crate::{
899         group::proposal::PreSharedKeyProposal,
900         psk::{JustPreSharedKeyID, PreSharedKey, PreSharedKeyID},
901     };
902 
903     use super::*;
904 
905     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_commit_builder_group() -> Group<TestClientConfig>906     async fn test_commit_builder_group() -> Group<TestClientConfig> {
907         test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
908             b.custom_proposal_type(ProposalType::from(42))
909                 .extension_type(TEST_EXTENSION_TYPE.into())
910         })
911         .await
912         .group
913     }
914 
assert_commit_builder_output<C: ClientConfig>( group: Group<C>, mut commit_output: CommitOutput, expected: Vec<Proposal>, welcome_count: usize, )915     fn assert_commit_builder_output<C: ClientConfig>(
916         group: Group<C>,
917         mut commit_output: CommitOutput,
918         expected: Vec<Proposal>,
919         welcome_count: usize,
920     ) {
921         let plaintext = commit_output.commit_message.into_plaintext().unwrap();
922 
923         let commit_data = match plaintext.content.content {
924             Content::Commit(commit) => commit,
925             #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
926             _ => panic!("Found non-commit data"),
927         };
928 
929         assert_eq!(commit_data.proposals.len(), expected.len());
930 
931         commit_data.proposals.into_iter().for_each(|proposal| {
932             let proposal = match proposal {
933                 ProposalOrRef::Proposal(p) => p,
934                 #[cfg(feature = "by_ref_proposal")]
935                 ProposalOrRef::Reference(_) => panic!("found proposal reference"),
936             };
937 
938             #[cfg(feature = "psk")]
939             if let Some(psk_id) = match proposal.as_ref() {
940                 Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(psk_id), .. },}) => Some(psk_id),
941                 _ => None,
942             } {
943                 let found = expected.iter().any(|item| matches!(item, Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(id), .. }}) if id == psk_id));
944 
945                 assert!(found)
946             } else {
947                 assert!(expected.contains(&proposal));
948             }
949 
950             #[cfg(not(feature = "psk"))]
951             assert!(expected.contains(&proposal));
952         });
953 
954         if welcome_count > 0 {
955             let welcome_msg = commit_output.welcome_messages.pop().unwrap();
956 
957             assert_eq!(welcome_msg.version, group.state.context.protocol_version);
958 
959             let welcome_msg = welcome_msg.into_welcome().unwrap();
960 
961             assert_eq!(welcome_msg.cipher_suite, group.state.context.cipher_suite);
962             assert_eq!(welcome_msg.secrets.len(), welcome_count);
963         } else {
964             assert!(commit_output.welcome_messages.is_empty());
965         }
966     }
967 
968     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_add()969     async fn test_commit_builder_add() {
970         let mut group = test_commit_builder_group().await;
971 
972         let test_key_package =
973             test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
974 
975         let commit_output = group
976             .commit_builder()
977             .add_member(test_key_package.clone())
978             .unwrap()
979             .build()
980             .await
981             .unwrap();
982 
983         let expected_add = group.add_proposal(test_key_package).unwrap();
984 
985         assert_commit_builder_output(group, commit_output, vec![expected_add], 1)
986     }
987 
988     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_add_with_ext()989     async fn test_commit_builder_add_with_ext() {
990         let mut group = test_commit_builder_group().await;
991 
992         let (bob_client, bob_key_package) =
993             test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
994 
995         let ext = TestExtension { foo: 42 };
996         let mut extension_list = ExtensionList::default();
997         extension_list.set_from(ext.clone()).unwrap();
998 
999         let welcome_message = group
1000             .commit_builder()
1001             .add_member(bob_key_package)
1002             .unwrap()
1003             .set_group_info_ext(extension_list)
1004             .build()
1005             .await
1006             .unwrap()
1007             .welcome_messages
1008             .remove(0);
1009 
1010         let (_, context) = bob_client.join_group(None, &welcome_message).await.unwrap();
1011 
1012         assert_eq!(
1013             context
1014                 .group_info_extensions
1015                 .get_as::<TestExtension>()
1016                 .unwrap()
1017                 .unwrap(),
1018             ext
1019         );
1020     }
1021 
1022     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_remove()1023     async fn test_commit_builder_remove() {
1024         let mut group = test_commit_builder_group().await;
1025         let test_key_package =
1026             test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1027 
1028         group
1029             .commit_builder()
1030             .add_member(test_key_package)
1031             .unwrap()
1032             .build()
1033             .await
1034             .unwrap();
1035 
1036         group.apply_pending_commit().await.unwrap();
1037 
1038         let commit_output = group
1039             .commit_builder()
1040             .remove_member(1)
1041             .unwrap()
1042             .build()
1043             .await
1044             .unwrap();
1045 
1046         let expected_remove = group.remove_proposal(1).unwrap();
1047 
1048         assert_commit_builder_output(group, commit_output, vec![expected_remove], 0);
1049     }
1050 
1051     #[cfg(feature = "psk")]
1052     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_psk()1053     async fn test_commit_builder_psk() {
1054         let mut group = test_commit_builder_group().await;
1055         let test_psk = ExternalPskId::new(vec![1]);
1056 
1057         group
1058             .config
1059             .secret_store()
1060             .insert(test_psk.clone(), PreSharedKey::from(vec![1]));
1061 
1062         let commit_output = group
1063             .commit_builder()
1064             .add_external_psk(test_psk.clone())
1065             .unwrap()
1066             .build()
1067             .await
1068             .unwrap();
1069 
1070         let key_id = JustPreSharedKeyID::External(test_psk);
1071         let expected_psk = group.psk_proposal(key_id).unwrap();
1072 
1073         assert_commit_builder_output(group, commit_output, vec![expected_psk], 0)
1074     }
1075 
1076     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_group_context_ext()1077     async fn test_commit_builder_group_context_ext() {
1078         let mut group = test_commit_builder_group().await;
1079         let mut test_ext = ExtensionList::default();
1080         test_ext
1081             .set_from(RequiredCapabilitiesExt::default())
1082             .unwrap();
1083 
1084         let commit_output = group
1085             .commit_builder()
1086             .set_group_context_ext(test_ext.clone())
1087             .unwrap()
1088             .build()
1089             .await
1090             .unwrap();
1091 
1092         let expected_ext = group.group_context_extensions_proposal(test_ext);
1093 
1094         assert_commit_builder_output(group, commit_output, vec![expected_ext], 0);
1095     }
1096 
1097     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_reinit()1098     async fn test_commit_builder_reinit() {
1099         let mut group = test_commit_builder_group().await;
1100         let test_group_id = "foo".as_bytes().to_vec();
1101         let test_cipher_suite = TEST_CIPHER_SUITE;
1102         let test_protocol_version = TEST_PROTOCOL_VERSION;
1103         let mut test_ext = ExtensionList::default();
1104 
1105         test_ext
1106             .set_from(RequiredCapabilitiesExt::default())
1107             .unwrap();
1108 
1109         let commit_output = group
1110             .commit_builder()
1111             .reinit(
1112                 Some(test_group_id.clone()),
1113                 test_protocol_version,
1114                 test_cipher_suite,
1115                 test_ext.clone(),
1116             )
1117             .unwrap()
1118             .build()
1119             .await
1120             .unwrap();
1121 
1122         let expected_reinit = group
1123             .reinit_proposal(
1124                 Some(test_group_id),
1125                 test_protocol_version,
1126                 test_cipher_suite,
1127                 test_ext,
1128             )
1129             .unwrap();
1130 
1131         assert_commit_builder_output(group, commit_output, vec![expected_reinit], 0);
1132     }
1133 
1134     #[cfg(feature = "custom_proposal")]
1135     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_custom_proposal()1136     async fn test_commit_builder_custom_proposal() {
1137         let mut group = test_commit_builder_group().await;
1138 
1139         let proposal = CustomProposal::new(42.into(), vec![0, 1]);
1140 
1141         let commit_output = group
1142             .commit_builder()
1143             .custom_proposal(proposal.clone())
1144             .build()
1145             .await
1146             .unwrap();
1147 
1148         assert_commit_builder_output(group, commit_output, vec![Proposal::Custom(proposal)], 0);
1149     }
1150 
1151     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_chaining()1152     async fn test_commit_builder_chaining() {
1153         let mut group = test_commit_builder_group().await;
1154         let kp1 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1155         let kp2 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
1156 
1157         let expected_adds = vec![
1158             group.add_proposal(kp1.clone()).unwrap(),
1159             group.add_proposal(kp2.clone()).unwrap(),
1160         ];
1161 
1162         let commit_output = group
1163             .commit_builder()
1164             .add_member(kp1)
1165             .unwrap()
1166             .add_member(kp2)
1167             .unwrap()
1168             .build()
1169             .await
1170             .unwrap();
1171 
1172         assert_commit_builder_output(group, commit_output, expected_adds, 2);
1173     }
1174 
1175     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_empty_commit()1176     async fn test_commit_builder_empty_commit() {
1177         let mut group = test_commit_builder_group().await;
1178 
1179         let commit_output = group.commit_builder().build().await.unwrap();
1180 
1181         assert_commit_builder_output(group, commit_output, vec![], 0);
1182     }
1183 
1184     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_authenticated_data()1185     async fn test_commit_builder_authenticated_data() {
1186         let mut group = test_commit_builder_group().await;
1187         let test_data = "test".as_bytes().to_vec();
1188 
1189         let commit_output = group
1190             .commit_builder()
1191             .authenticated_data(test_data.clone())
1192             .build()
1193             .await
1194             .unwrap();
1195 
1196         assert_eq!(
1197             commit_output
1198                 .commit_message
1199                 .into_plaintext()
1200                 .unwrap()
1201                 .content
1202                 .authenticated_data,
1203             test_data
1204         );
1205     }
1206 
1207     #[cfg(feature = "by_ref_proposal")]
1208     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_multiple_welcome_messages()1209     async fn test_commit_builder_multiple_welcome_messages() {
1210         let mut group = test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
1211             let options = CommitOptions::new().with_single_welcome_message(false);
1212             b.mls_rules(DefaultMlsRules::new().with_commit_options(options))
1213         })
1214         .await;
1215 
1216         let (alice, alice_kp) =
1217             test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "a").await;
1218 
1219         let (bob, bob_kp) =
1220             test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "b").await;
1221 
1222         group
1223             .group
1224             .propose_add(alice_kp.clone(), vec![])
1225             .await
1226             .unwrap();
1227 
1228         group
1229             .group
1230             .propose_add(bob_kp.clone(), vec![])
1231             .await
1232             .unwrap();
1233 
1234         let output = group.group.commit(Vec::new()).await.unwrap();
1235         let welcomes = output.welcome_messages;
1236 
1237         let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
1238 
1239         for (client, kp) in [(alice, alice_kp), (bob, bob_kp)] {
1240             let kp_ref = kp.key_package_reference(&cs).await.unwrap().unwrap();
1241 
1242             let welcome = welcomes
1243                 .iter()
1244                 .find(|w| w.welcome_key_package_references().contains(&&kp_ref))
1245                 .unwrap();
1246 
1247             client.join_group(None, welcome).await.unwrap();
1248 
1249             assert_eq!(welcome.clone().into_welcome().unwrap().secrets.len(), 1);
1250         }
1251     }
1252 
1253     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
commit_can_change_credential()1254     async fn commit_can_change_credential() {
1255         let cs = TEST_CIPHER_SUITE;
1256         let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, cs, 3).await;
1257         let (identity, secret_key) = get_test_signing_identity(cs, b"member").await;
1258 
1259         let commit_output = groups[0]
1260             .group
1261             .commit_builder()
1262             .set_new_signing_identity(secret_key, identity.clone())
1263             .build()
1264             .await
1265             .unwrap();
1266 
1267         // Check that the credential was updated by in the committer's state.
1268         groups[0].process_pending_commit().await.unwrap();
1269         let new_member = groups[0].group.roster().member_with_index(0).unwrap();
1270 
1271         assert_eq!(
1272             new_member.signing_identity.credential,
1273             get_test_basic_credential(b"member".to_vec())
1274         );
1275 
1276         assert_eq!(
1277             new_member.signing_identity.signature_key,
1278             identity.signature_key
1279         );
1280 
1281         // Check that the credential was updated in another member's state.
1282         groups[1]
1283             .process_message(commit_output.commit_message)
1284             .await
1285             .unwrap();
1286 
1287         let new_member = groups[1].group.roster().member_with_index(0).unwrap();
1288 
1289         assert_eq!(
1290             new_member.signing_identity.credential,
1291             get_test_basic_credential(b"member".to_vec())
1292         );
1293 
1294         assert_eq!(
1295             new_member.signing_identity.signature_key,
1296             identity.signature_key
1297         );
1298     }
1299 
1300     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
commit_includes_tree_if_no_ratchet_tree_ext()1301     async fn commit_includes_tree_if_no_ratchet_tree_ext() {
1302         let mut group = test_group_custom(
1303             TEST_PROTOCOL_VERSION,
1304             TEST_CIPHER_SUITE,
1305             Default::default(),
1306             None,
1307             Some(CommitOptions::new().with_ratchet_tree_extension(false)),
1308         )
1309         .await
1310         .group;
1311 
1312         let commit = group.commit(vec![]).await.unwrap();
1313 
1314         group.apply_pending_commit().await.unwrap();
1315 
1316         let new_tree = group.export_tree();
1317 
1318         assert_eq!(new_tree, commit.ratchet_tree.unwrap())
1319     }
1320 
1321     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
commit_does_not_include_tree_if_ratchet_tree_ext()1322     async fn commit_does_not_include_tree_if_ratchet_tree_ext() {
1323         let mut group = test_group_custom(
1324             TEST_PROTOCOL_VERSION,
1325             TEST_CIPHER_SUITE,
1326             Default::default(),
1327             None,
1328             Some(CommitOptions::new().with_ratchet_tree_extension(true)),
1329         )
1330         .await
1331         .group;
1332 
1333         let commit = group.commit(vec![]).await.unwrap();
1334 
1335         assert!(commit.ratchet_tree.is_none());
1336     }
1337 
1338     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
commit_includes_external_commit_group_info_if_requested()1339     async fn commit_includes_external_commit_group_info_if_requested() {
1340         let mut group = test_group_custom(
1341             TEST_PROTOCOL_VERSION,
1342             TEST_CIPHER_SUITE,
1343             Default::default(),
1344             None,
1345             Some(
1346                 CommitOptions::new()
1347                     .with_allow_external_commit(true)
1348                     .with_ratchet_tree_extension(false),
1349             ),
1350         )
1351         .await
1352         .group;
1353 
1354         let commit = group.commit(vec![]).await.unwrap();
1355 
1356         let info = commit
1357             .external_commit_group_info
1358             .unwrap()
1359             .into_group_info()
1360             .unwrap();
1361 
1362         assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1363         assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
1364     }
1365 
1366     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
commit_includes_external_commit_and_tree_if_requested()1367     async fn commit_includes_external_commit_and_tree_if_requested() {
1368         let mut group = test_group_custom(
1369             TEST_PROTOCOL_VERSION,
1370             TEST_CIPHER_SUITE,
1371             Default::default(),
1372             None,
1373             Some(
1374                 CommitOptions::new()
1375                     .with_allow_external_commit(true)
1376                     .with_ratchet_tree_extension(true),
1377             ),
1378         )
1379         .await
1380         .group;
1381 
1382         let commit = group.commit(vec![]).await.unwrap();
1383 
1384         let info = commit
1385             .external_commit_group_info
1386             .unwrap()
1387             .into_group_info()
1388             .unwrap();
1389 
1390         assert!(info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1391         assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
1392     }
1393 
1394     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
commit_does_not_include_external_commit_group_info_if_not_requested()1395     async fn commit_does_not_include_external_commit_group_info_if_not_requested() {
1396         let mut group = test_group_custom(
1397             TEST_PROTOCOL_VERSION,
1398             TEST_CIPHER_SUITE,
1399             Default::default(),
1400             None,
1401             Some(CommitOptions::new().with_allow_external_commit(false)),
1402         )
1403         .await
1404         .group;
1405 
1406         let commit = group.commit(vec![]).await.unwrap();
1407 
1408         assert!(commit.external_commit_group_info.is_none());
1409     }
1410 
1411     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
member_identity_is_validated_against_new_extensions()1412     async fn member_identity_is_validated_against_new_extensions() {
1413         let alice = client_with_test_extension(b"alice").await;
1414         let mut alice = alice.create_group(ExtensionList::new()).await.unwrap();
1415 
1416         let bob = client_with_test_extension(b"bob").await;
1417         let bob_kp = bob.generate_key_package_message().await.unwrap();
1418 
1419         let mut extension_list = ExtensionList::new();
1420         let extension = TestExtension { foo: b'a' };
1421         extension_list.set_from(extension).unwrap();
1422 
1423         let res = alice
1424             .commit_builder()
1425             .add_member(bob_kp)
1426             .unwrap()
1427             .set_group_context_ext(extension_list.clone())
1428             .unwrap()
1429             .build()
1430             .await;
1431 
1432         assert!(res.is_err());
1433 
1434         let alex = client_with_test_extension(b"alex").await;
1435 
1436         alice
1437             .commit_builder()
1438             .add_member(alex.generate_key_package_message().await.unwrap())
1439             .unwrap()
1440             .set_group_context_ext(extension_list.clone())
1441             .unwrap()
1442             .build()
1443             .await
1444             .unwrap();
1445     }
1446 
1447     #[cfg(feature = "by_ref_proposal")]
1448     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
server_identity_is_validated_against_new_extensions()1449     async fn server_identity_is_validated_against_new_extensions() {
1450         let alice = client_with_test_extension(b"alice").await;
1451         let mut alice = alice.create_group(ExtensionList::new()).await.unwrap();
1452 
1453         let mut extension_list = ExtensionList::new();
1454         let extension = TestExtension { foo: b'a' };
1455         extension_list.set_from(extension).unwrap();
1456 
1457         let (alex_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"alex").await;
1458 
1459         let mut alex_extensions = extension_list.clone();
1460 
1461         alex_extensions
1462             .set_from(ExternalSendersExt {
1463                 allowed_senders: vec![alex_server],
1464             })
1465             .unwrap();
1466 
1467         let res = alice
1468             .commit_builder()
1469             .set_group_context_ext(alex_extensions)
1470             .unwrap()
1471             .build()
1472             .await;
1473 
1474         assert!(res.is_err());
1475 
1476         let (bob_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
1477 
1478         let mut bob_extensions = extension_list;
1479 
1480         bob_extensions
1481             .set_from(ExternalSendersExt {
1482                 allowed_senders: vec![bob_server],
1483             })
1484             .unwrap();
1485 
1486         alice
1487             .commit_builder()
1488             .set_group_context_ext(bob_extensions)
1489             .unwrap()
1490             .build()
1491             .await
1492             .unwrap();
1493     }
1494 
1495     #[derive(Debug, Clone)]
1496     struct IdentityProviderWithExtension(BasicIdentityProvider);
1497 
1498     #[derive(Clone, Debug)]
1499     #[cfg_attr(feature = "std", derive(thiserror::Error))]
1500     #[cfg_attr(feature = "std", error("test error"))]
1501     struct IdentityProviderWithExtensionError {}
1502 
1503     impl IntoAnyError for IdentityProviderWithExtensionError {
1504         #[cfg(feature = "std")]
into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self>1505         fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
1506             Ok(self.into())
1507         }
1508     }
1509 
1510     impl IdentityProviderWithExtension {
1511         // True if the identity starts with the character `foo` from `TestExtension` or if `TestExtension`
1512         // is not set.
1513         #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
starts_with_foo( &self, identity: &SigningIdentity, _timestamp: Option<MlsTime>, extensions: Option<&ExtensionList>, ) -> bool1514         async fn starts_with_foo(
1515             &self,
1516             identity: &SigningIdentity,
1517             _timestamp: Option<MlsTime>,
1518             extensions: Option<&ExtensionList>,
1519         ) -> bool {
1520             if let Some(extensions) = extensions {
1521                 if let Some(ext) = extensions.get_as::<TestExtension>().unwrap() {
1522                     self.identity(identity, extensions).await.unwrap()[0] == ext.foo
1523                 } else {
1524                     true
1525                 }
1526             } else {
1527                 true
1528             }
1529         }
1530     }
1531 
1532     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1533     #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
1534     impl IdentityProvider for IdentityProviderWithExtension {
1535         type Error = IdentityProviderWithExtensionError;
1536 
validate_member( &self, identity: &SigningIdentity, timestamp: Option<MlsTime>, extensions: Option<&ExtensionList>, ) -> Result<(), Self::Error>1537         async fn validate_member(
1538             &self,
1539             identity: &SigningIdentity,
1540             timestamp: Option<MlsTime>,
1541             extensions: Option<&ExtensionList>,
1542         ) -> Result<(), Self::Error> {
1543             self.starts_with_foo(identity, timestamp, extensions)
1544                 .await
1545                 .then_some(())
1546                 .ok_or(IdentityProviderWithExtensionError {})
1547         }
1548 
validate_external_sender( &self, identity: &SigningIdentity, timestamp: Option<MlsTime>, extensions: Option<&ExtensionList>, ) -> Result<(), Self::Error>1549         async fn validate_external_sender(
1550             &self,
1551             identity: &SigningIdentity,
1552             timestamp: Option<MlsTime>,
1553             extensions: Option<&ExtensionList>,
1554         ) -> Result<(), Self::Error> {
1555             (!self.starts_with_foo(identity, timestamp, extensions).await)
1556                 .then_some(())
1557                 .ok_or(IdentityProviderWithExtensionError {})
1558         }
1559 
identity( &self, signing_identity: &SigningIdentity, extensions: &ExtensionList, ) -> Result<Vec<u8>, Self::Error>1560         async fn identity(
1561             &self,
1562             signing_identity: &SigningIdentity,
1563             extensions: &ExtensionList,
1564         ) -> Result<Vec<u8>, Self::Error> {
1565             self.0
1566                 .identity(signing_identity, extensions)
1567                 .await
1568                 .map_err(|_| IdentityProviderWithExtensionError {})
1569         }
1570 
valid_successor( &self, _predecessor: &SigningIdentity, _successor: &SigningIdentity, _extensions: &ExtensionList, ) -> Result<bool, Self::Error>1571         async fn valid_successor(
1572             &self,
1573             _predecessor: &SigningIdentity,
1574             _successor: &SigningIdentity,
1575             _extensions: &ExtensionList,
1576         ) -> Result<bool, Self::Error> {
1577             Ok(true)
1578         }
1579 
supported_types(&self) -> Vec<CredentialType>1580         fn supported_types(&self) -> Vec<CredentialType> {
1581             self.0.supported_types()
1582         }
1583     }
1584 
1585     type ExtensionClientConfig = WithIdentityProvider<
1586         IdentityProviderWithExtension,
1587         WithCryptoProvider<TestCryptoProvider, BaseConfig>,
1588     >;
1589 
1590     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
client_with_test_extension(name: &[u8]) -> Client<ExtensionClientConfig>1591     async fn client_with_test_extension(name: &[u8]) -> Client<ExtensionClientConfig> {
1592         let (identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, name).await;
1593 
1594         ClientBuilder::new()
1595             .crypto_provider(TestCryptoProvider::new())
1596             .extension_types(vec![TEST_EXTENSION_TYPE.into()])
1597             .identity_provider(IdentityProviderWithExtension(BasicIdentityProvider::new()))
1598             .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
1599             .build()
1600     }
1601 }
1602