1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec::Vec;
6 
7 use mls_rs_core::{
8     crypto::{CipherSuite, SignatureSecretKey},
9     extension::ExtensionList,
10     identity::SigningIdentity,
11     protocol_version::ProtocolVersion,
12 };
13 
14 use crate::{client::MlsError, Client, Group, MlsMessage};
15 
16 use super::{
17     proposal::ReInitProposal, ClientConfig, ExportedTree, JustPreSharedKeyID, MessageProcessor,
18     NewMemberInfo, PreSharedKeyID, PskGroupId, PskSecretInput, ResumptionPSKUsage, ResumptionPsk,
19 };
20 
21 struct ResumptionGroupParameters<'a> {
22     group_id: &'a [u8],
23     cipher_suite: CipherSuite,
24     version: ProtocolVersion,
25     extensions: &'a ExtensionList,
26 }
27 
28 pub struct ReinitClient<C: ClientConfig + Clone> {
29     client: Client<C>,
30     reinit: ReInitProposal,
31     psk_input: PskSecretInput,
32 }
33 
34 impl<C> Group<C>
35 where
36     C: ClientConfig + Clone,
37 {
38     /// Create a sub-group from a subset of the current group members.
39     ///
40     /// Membership within the resulting sub-group is indicated by providing a
41     /// key package that produces the same
42     /// [identity](crate::IdentityProvider::identity) value
43     /// as an existing group member. The identity value of each key package
44     /// is determined using the
45     /// [`IdentityProvider`](crate::IdentityProvider)
46     /// that is currently in use by this group instance.
47     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
branch( &self, sub_group_id: Vec<u8>, new_key_packages: Vec<MlsMessage>, ) -> Result<(Group<C>, Vec<MlsMessage>), MlsError>48     pub async fn branch(
49         &self,
50         sub_group_id: Vec<u8>,
51         new_key_packages: Vec<MlsMessage>,
52     ) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
53         let new_group_params = ResumptionGroupParameters {
54             group_id: &sub_group_id,
55             cipher_suite: self.cipher_suite(),
56             version: self.protocol_version(),
57             extensions: &self.group_state().context.extensions,
58         };
59 
60         resumption_create_group(
61             self.config.clone(),
62             new_key_packages,
63             &new_group_params,
64             // TODO investigate if it's worth updating your own signing identity here
65             self.current_member_signing_identity()?.clone(),
66             self.signer.clone(),
67             #[cfg(any(feature = "private_message", feature = "psk"))]
68             self.resumption_psk_input(ResumptionPSKUsage::Branch)?,
69         )
70         .await
71     }
72 
73     /// Join a subgroup that was created by [`Group::branch`].
74     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
join_subgroup( &self, welcome: &MlsMessage, tree_data: Option<ExportedTree<'_>>, ) -> Result<(Group<C>, NewMemberInfo), MlsError>75     pub async fn join_subgroup(
76         &self,
77         welcome: &MlsMessage,
78         tree_data: Option<ExportedTree<'_>>,
79     ) -> Result<(Group<C>, NewMemberInfo), MlsError> {
80         let expected_new_group_prams = ResumptionGroupParameters {
81             group_id: &[],
82             cipher_suite: self.cipher_suite(),
83             version: self.protocol_version(),
84             extensions: &self.group_state().context.extensions,
85         };
86 
87         resumption_join_group(
88             self.config.clone(),
89             self.signer.clone(),
90             welcome,
91             tree_data,
92             expected_new_group_prams,
93             false,
94             self.resumption_psk_input(ResumptionPSKUsage::Branch)?,
95         )
96         .await
97     }
98 
99     /// Generate a [`ReinitClient`] that can be used to create or join a new group
100     /// that is based on properties defined by a [`ReInitProposal`]
101     /// committed in a previously accepted commit. This is the only action available
102     /// after accepting such a commit. The old group can no longer be used according to the RFC.
103     ///
104     /// If the [`ReInitProposal`] changes the ciphersuite, then `new_signer`
105     /// and `new_signer_identity` must be set and match the new ciphersuite, as indicated by
106     /// [`pending_reinit_ciphersuite`](crate::group::StateUpdate::pending_reinit_ciphersuite)
107     /// of the [`StateUpdate`](crate::group::StateUpdate) outputted after processing the
108     /// commit to the reinit proposal. The value of [identity](crate::IdentityProvider::identity)
109     /// must be the same for `new_signing_identity` and the current identity in use by this
110     /// group instance.
get_reinit_client( self, new_signer: Option<SignatureSecretKey>, new_signing_identity: Option<SigningIdentity>, ) -> Result<ReinitClient<C>, MlsError>111     pub fn get_reinit_client(
112         self,
113         new_signer: Option<SignatureSecretKey>,
114         new_signing_identity: Option<SigningIdentity>,
115     ) -> Result<ReinitClient<C>, MlsError> {
116         let psk_input = self.resumption_psk_input(ResumptionPSKUsage::Reinit)?;
117 
118         let new_signing_identity = new_signing_identity
119             .map(Ok)
120             .unwrap_or_else(|| self.current_member_signing_identity().cloned())?;
121 
122         let reinit = self
123             .state
124             .pending_reinit
125             .ok_or(MlsError::PendingReInitNotFound)?;
126 
127         let new_signer = match new_signer {
128             Some(signer) => signer,
129             None => self.signer,
130         };
131 
132         let client = Client::new(
133             self.config,
134             Some(new_signer),
135             Some((new_signing_identity, reinit.new_cipher_suite())),
136             reinit.new_version(),
137         );
138 
139         Ok(ReinitClient {
140             client,
141             reinit,
142             psk_input,
143         })
144     }
145 
resumption_psk_input(&self, usage: ResumptionPSKUsage) -> Result<PskSecretInput, MlsError>146     fn resumption_psk_input(&self, usage: ResumptionPSKUsage) -> Result<PskSecretInput, MlsError> {
147         let psk = self.epoch_secrets.resumption_secret.clone();
148 
149         let id = JustPreSharedKeyID::Resumption(ResumptionPsk {
150             usage,
151             psk_group_id: PskGroupId(self.group_id().to_vec()),
152             psk_epoch: self.current_epoch(),
153         });
154 
155         let id = PreSharedKeyID::new(id, self.cipher_suite_provider())?;
156         Ok(PskSecretInput { id, psk })
157     }
158 }
159 
160 /// A [`Client`] that can be used to create or join a new group
161 /// that is based on properties defined by a [`ReInitProposal`]
162 /// committed in a previously accepted commit.
163 impl<C: ClientConfig + Clone> ReinitClient<C> {
164     /// Generate a key package for the new group. The key package can
165     /// be used in [`ReinitClient::commit`].
166     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
generate_key_package(&self) -> Result<MlsMessage, MlsError>167     pub async fn generate_key_package(&self) -> Result<MlsMessage, MlsError> {
168         self.client.generate_key_package_message().await
169     }
170 
171     /// Create the new group using new key packages of all group members, possibly
172     /// generated by [`ReinitClient::generate_key_package`].
173     ///
174     /// # Warning
175     ///
176     /// This function will fail if the number of members in the reinitialized
177     /// group is not the same as the prior group roster.
178     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
commit( self, new_key_packages: Vec<MlsMessage>, ) -> Result<(Group<C>, Vec<MlsMessage>), MlsError>179     pub async fn commit(
180         self,
181         new_key_packages: Vec<MlsMessage>,
182     ) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
183         let new_group_params = ResumptionGroupParameters {
184             group_id: self.reinit.group_id(),
185             cipher_suite: self.reinit.new_cipher_suite(),
186             version: self.reinit.new_version(),
187             extensions: self.reinit.new_group_context_extensions(),
188         };
189 
190         resumption_create_group(
191             self.client.config.clone(),
192             new_key_packages,
193             &new_group_params,
194             // These private fields are created with `Some(x)` by `get_reinit_client`
195             self.client.signing_identity.unwrap().0,
196             self.client.signer.unwrap(),
197             #[cfg(any(feature = "private_message", feature = "psk"))]
198             self.psk_input,
199         )
200         .await
201     }
202 
203     /// Join a reinitialized group that was created by [`ReinitClient::commit`].
204     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
join( self, welcome: &MlsMessage, tree_data: Option<ExportedTree<'_>>, ) -> Result<(Group<C>, NewMemberInfo), MlsError>205     pub async fn join(
206         self,
207         welcome: &MlsMessage,
208         tree_data: Option<ExportedTree<'_>>,
209     ) -> Result<(Group<C>, NewMemberInfo), MlsError> {
210         let reinit = self.reinit;
211 
212         let expected_group_params = ResumptionGroupParameters {
213             group_id: reinit.group_id(),
214             cipher_suite: reinit.new_cipher_suite(),
215             version: reinit.new_version(),
216             extensions: reinit.new_group_context_extensions(),
217         };
218 
219         resumption_join_group(
220             self.client.config,
221             // This private field is created with `Some(x)` by `get_reinit_client`
222             self.client.signer.unwrap(),
223             welcome,
224             tree_data,
225             expected_group_params,
226             true,
227             self.psk_input,
228         )
229         .await
230     }
231 }
232 
233 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
resumption_create_group<C: ClientConfig + Clone>( config: C, new_key_packages: Vec<MlsMessage>, new_group_params: &ResumptionGroupParameters<'_>, signing_identity: SigningIdentity, signer: SignatureSecretKey, psk_input: PskSecretInput, ) -> Result<(Group<C>, Vec<MlsMessage>), MlsError>234 async fn resumption_create_group<C: ClientConfig + Clone>(
235     config: C,
236     new_key_packages: Vec<MlsMessage>,
237     new_group_params: &ResumptionGroupParameters<'_>,
238     signing_identity: SigningIdentity,
239     signer: SignatureSecretKey,
240     psk_input: PskSecretInput,
241 ) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
242     // Create a new group with new parameters
243     let mut group = Group::new(
244         config,
245         Some(new_group_params.group_id.to_vec()),
246         new_group_params.cipher_suite,
247         new_group_params.version,
248         signing_identity,
249         new_group_params.extensions.clone(),
250         signer,
251     )
252     .await?;
253 
254     // Install the resumption psk in the new group
255     group.previous_psk = Some(psk_input);
256 
257     // Create a commit that adds new key packages and uses the resumption PSK
258     let mut commit = group.commit_builder();
259 
260     for kp in new_key_packages.into_iter() {
261         commit = commit.add_member(kp)?;
262     }
263 
264     let commit = commit.build().await?;
265     group.apply_pending_commit().await?;
266 
267     // Uninstall the resumption psk on success (in case of failure, the new group is discarded anyway)
268     group.previous_psk = None;
269 
270     Ok((group, commit.welcome_messages))
271 }
272 
273 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
resumption_join_group<C: ClientConfig + Clone>( config: C, signer: SignatureSecretKey, welcome: &MlsMessage, tree_data: Option<ExportedTree<'_>>, expected_new_group_params: ResumptionGroupParameters<'_>, verify_group_id: bool, psk_input: PskSecretInput, ) -> Result<(Group<C>, NewMemberInfo), MlsError>274 async fn resumption_join_group<C: ClientConfig + Clone>(
275     config: C,
276     signer: SignatureSecretKey,
277     welcome: &MlsMessage,
278     tree_data: Option<ExportedTree<'_>>,
279     expected_new_group_params: ResumptionGroupParameters<'_>,
280     verify_group_id: bool,
281     psk_input: PskSecretInput,
282 ) -> Result<(Group<C>, NewMemberInfo), MlsError> {
283     let psk_input = Some(psk_input);
284 
285     let (group, new_member_info) =
286         Group::<C>::from_welcome_message(welcome, tree_data, config, signer, psk_input).await?;
287 
288     if group.protocol_version() != expected_new_group_params.version {
289         Err(MlsError::ProtocolVersionMismatch)
290     } else if group.cipher_suite() != expected_new_group_params.cipher_suite {
291         Err(MlsError::CipherSuiteMismatch)
292     } else if verify_group_id && group.group_id() != expected_new_group_params.group_id {
293         Err(MlsError::GroupIdMismatch)
294     } else if &group.group_state().context.extensions != expected_new_group_params.extensions {
295         Err(MlsError::ReInitExtensionsMismatch)
296     } else {
297         Ok((group, new_member_info))
298     }
299 }
300