1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use crate::{
6     client::MlsError,
7     group::{
8         proposal::ReInitProposal,
9         proposal_filter::{ProposalBundle, ProposalInfo},
10         AddProposal, ProposalType, RemoveProposal, Sender, UpdateProposal,
11     },
12     iter::wrap_iter,
13     protocol_version::ProtocolVersion,
14     time::MlsTime,
15     tree_kem::{
16         leaf_node_validator::{LeafNodeValidator, ValidationContext},
17         node::LeafIndex,
18         TreeKemPublic,
19     },
20     CipherSuiteProvider, ExtensionList,
21 };
22 
23 use super::filtering_common::{filter_out_invalid_psks, ApplyProposalsOutput, ProposalApplier};
24 
25 #[cfg(feature = "by_ref_proposal")]
26 use crate::extension::ExternalSendersExt;
27 
28 use alloc::vec::Vec;
29 use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider, psk::PreSharedKeyStorage};
30 
31 #[cfg(any(
32     feature = "custom_proposal",
33     not(any(mls_build_async, feature = "rayon"))
34 ))]
35 use itertools::Itertools;
36 
37 use crate::group::ExternalInit;
38 
39 #[cfg(feature = "psk")]
40 use crate::group::proposal::PreSharedKeyProposal;
41 
42 #[cfg(all(not(mls_build_async), feature = "rayon"))]
43 use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
44 
45 #[cfg(mls_build_async)]
46 use futures::{StreamExt, TryStreamExt};
47 
// Member-commit filtering and application for `ProposalApplier`.
//
// Every filter in this file follows the same contract: a proposal that fails a
// check is handed to `apply_strategy`, which either drops it from the bundle or
// turns the failure into an error, depending on the `FilterStrategy`.
impl<'a, C, P, CSP> ProposalApplier<'a, C, P, CSP>
where
    C: IdentityProvider,
    P: PreSharedKeyStorage,
    CSP: CipherSuiteProvider,
{
    /// Filters a proposal bundle committed by the member at `commit_sender`,
    /// then applies the surviving proposals.
    ///
    /// The filtering passes run in this order:
    /// 1. sender / proposal-type permissions (`filter_out_invalid_proposers`),
    /// 2. Update proposals sent by the committer itself,
    /// 3. Remove proposals targeting the committer,
    /// 4. PSK proposals (`filter_out_invalid_psks`),
    /// 5. external-senders group extensions (only with `by_ref_proposal`),
    /// 6. more than one group context extensions proposal,
    /// 7. ReInit proposals with an older protocol version,
    /// 8. ReInit combined with other proposals,
    /// 9. ExternalInit proposals.
    ///
    /// The result is passed to `apply_proposal_changes`.
    ///
    /// # Errors
    ///
    /// Returns an error when a failing proposal may not be ignored under
    /// `strategy`, or when validating or applying the proposals fails.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(super) async fn apply_proposals_from_member(
        &self,
        strategy: FilterStrategy,
        commit_sender: LeafIndex,
        proposals: ProposalBundle,
        commit_time: Option<MlsTime>,
    ) -> Result<ApplyProposalsOutput, MlsError> {
        let proposals = filter_out_invalid_proposers(strategy, proposals)?;

        let mut proposals: ProposalBundle =
            filter_out_update_for_committer(strategy, commit_sender, proposals)?;

        // We ignore the strategy here because the check above ensures all updates are from members
        // (`proposer_can_propose` only allows UPDATE from `Sender::Member`). `update_senders` is
        // rebuilt from scratch so that it is index-aligned with `updates`.
        proposals.update_senders = proposals
            .updates
            .iter()
            .map(leaf_index_of_update_sender)
            .collect::<Result<_, _>>()?;

        let mut proposals = filter_out_removal_of_committer(strategy, commit_sender, proposals)?;

        // PSK filtering works on the bundle in place.
        filter_out_invalid_psks(
            strategy,
            self.cipher_suite_provider,
            &mut proposals,
            self.psk_storage,
        )
        .await?;

        #[cfg(feature = "by_ref_proposal")]
        let proposals = filter_out_invalid_group_extensions(
            strategy,
            proposals,
            self.identity_provider,
            commit_time,
        )
        .await?;

        let proposals = filter_out_extra_group_context_extensions(strategy, proposals)?;
        let proposals = filter_out_invalid_reinit(strategy, proposals, self.protocol_version)?;
        // This filter takes a plain "may drop by-reference ReInits" flag rather than the strategy.
        let proposals = filter_out_reinit_if_other_proposals(strategy.is_ignore(), proposals)?;

        let proposals = filter_out_external_init(strategy, proposals)?;

        self.apply_proposal_changes(strategy, proposals, commit_time)
            .await
    }

    /// Applies an already-filtered bundle.
    ///
    /// If the bundle contains a group context extensions proposal, a clone of it
    /// is passed to `apply_proposals_with_new_capabilities` (defined elsewhere).
    /// Otherwise the tree changes are validated and applied against the group's
    /// current extensions (`self.original_group_extensions`).
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(super) async fn apply_proposal_changes(
        &self,
        strategy: FilterStrategy,
        proposals: ProposalBundle,
        commit_time: Option<MlsTime>,
    ) -> Result<ApplyProposalsOutput, MlsError> {
        match proposals.group_context_extensions_proposal().cloned() {
            Some(p) => {
                self.apply_proposals_with_new_capabilities(strategy, proposals, p, commit_time)
                    .await
            }
            None => {
                self.apply_tree_changes(
                    strategy,
                    proposals,
                    self.original_group_extensions,
                    commit_time,
                )
                .await
            }
        }
    }

    /// Validates the new leaf nodes in the bundle, then applies the proposals to
    /// a clone of the original ratchet tree.
    ///
    /// `group_extensions_in_use` is the extension list that leaf nodes are
    /// validated against and that is passed to `TreeKemPublic::batch_edit`.
    ///
    /// The returned `ApplyProposalsOutput` holds:
    /// - the bundle as left after validation and `batch_edit` (which receives it
    ///   mutably),
    /// - the edited tree,
    /// - the added-key-package indexes returned by `batch_edit`,
    /// - `external_init_index: None`,
    /// - the contents of the group context extensions proposal, if one remains.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(super) async fn apply_tree_changes(
        &self,
        strategy: FilterStrategy,
        proposals: ProposalBundle,
        group_extensions_in_use: &ExtensionList,
        commit_time: Option<MlsTime>,
    ) -> Result<ApplyProposalsOutput, MlsError> {
        let mut applied_proposals = self
            .validate_new_nodes(strategy, proposals, group_extensions_in_use, commit_time)
            .await?;

        // Work on a copy so that the original tree is left untouched.
        let mut new_tree = self.original_tree.clone();

        let added = new_tree
            .batch_edit(
                &mut applied_proposals,
                group_extensions_in_use,
                self.identity_provider,
                self.cipher_suite_provider,
                strategy.is_ignore(),
            )
            .await?;

        let new_context_extensions = applied_proposals
            .group_context_extensions_proposal()
            .map(|gce| gce.proposal.clone());

        Ok(ApplyProposalsOutput {
            applied_proposals,
            new_tree,
            indexes_of_added_kpkgs: added,
            external_init_index: None,
            new_context_extensions,
        })
    }

    /// Validates the leaf nodes introduced by Update and Add proposals.
    ///
    /// For each Update proposal, two checks are made and both are required:
    /// - the new leaf must pass `LeafNodeValidator::check_if_valid` in the
    ///   `Update` context for the sender's leaf index;
    /// - the identity provider must accept the new signing identity as a
    ///   `valid_successor` of the sender's current one.
    ///
    /// For each Add proposal, the key package goes through `validate_new_node`
    /// (defined elsewhere). Failing proposals are dropped or reported according
    /// to `strategy`.
    ///
    /// # Errors
    ///
    /// Returns a validation error that `strategy` does not allow to be ignored.
    /// A failure to look up the sender's current leaf in the original tree is
    /// always returned as-is, without consulting `strategy`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn validate_new_nodes(
        &self,
        strategy: FilterStrategy,
        mut proposals: ProposalBundle,
        group_extensions_in_use: &ExtensionList,
        commit_time: Option<MlsTime>,
    ) -> Result<ProposalBundle, MlsError> {
        let leaf_node_validator = &LeafNodeValidator::new(
            self.cipher_suite_provider,
            self.identity_provider,
            Some(group_extensions_in_use),
        );

        // Pair each update with its sender. This assumes `update_senders` is index-aligned with
        // the update proposals, as set up in `apply_proposals_from_member`.
        // NOTE(review): `wrap_iter` (defined elsewhere) presumably adapts iteration to the
        // sync/async/rayon build — confirm ordering guarantees there.
        let bad_indices: Vec<_> = wrap_iter(proposals.update_proposals())
            .zip(wrap_iter(proposals.update_proposal_senders()))
            .enumerate()
            .filter_map(|(i, (p, &sender_index))| async move {
                let res = {
                    let leaf = &p.proposal.leaf_node;

                    let res = leaf_node_validator
                        .check_if_valid(
                            leaf,
                            ValidationContext::Update((self.group_id, *sender_index, commit_time)),
                        )
                        .await;

                    // A missing sender leaf is a hard error that bypasses the filter strategy.
                    let old_leaf = match self.original_tree.get_leaf_node(sender_index) {
                        Ok(leaf) => leaf,
                        Err(e) => return Some(Err(e)),
                    };

                    let valid_successor = self
                        .identity_provider
                        .valid_successor(
                            &old_leaf.signing_identity,
                            &leaf.signing_identity,
                            group_extensions_in_use,
                        )
                        .await
                        .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))
                        .and_then(|valid| valid.then_some(()).ok_or(MlsError::InvalidSuccessor));

                    // If both checks fail, the leaf-validation error is the one reported.
                    res.and(valid_successor)
                };

                // Yield the index only for proposals that should be dropped; errors propagate.
                apply_strategy(strategy, p.is_by_reference(), res)
                    .map(|b| (!b).then_some(i))
                    .transpose()
            })
            .try_collect()
            .await?;

        // Remove from the back so that earlier removals do not shift later indices. Update
        // proposals and their senders are removed together to keep them aligned.
        bad_indices.into_iter().rev().for_each(|i| {
            proposals.remove::<UpdateProposal>(i);
            proposals.update_senders.remove(i);
        });

        let bad_indices: Vec<_> = wrap_iter(proposals.add_proposals())
            .enumerate()
            .filter_map(|(i, p)| async move {
                let res = self
                    .validate_new_node(leaf_node_validator, &p.proposal.key_package, commit_time)
                    .await;

                apply_strategy(strategy, p.is_by_reference(), res)
                    .map(|b| (!b).then_some(i))
                    .transpose()
            })
            .try_collect()
            .await?;

        bad_indices
            .into_iter()
            .rev()
            .for_each(|i| proposals.remove::<AddProposal>(i));

        Ok(proposals)
    }
}
245 
246 #[derive(Clone, Copy, Debug)]
247 pub enum FilterStrategy {
248     IgnoreByRef,
249     IgnoreNone,
250 }
251 
252 impl FilterStrategy {
ignore(self, by_ref: bool) -> bool253     pub(super) fn ignore(self, by_ref: bool) -> bool {
254         match self {
255             FilterStrategy::IgnoreByRef => by_ref,
256             FilterStrategy::IgnoreNone => false,
257         }
258     }
259 
is_ignore(self) -> bool260     fn is_ignore(self) -> bool {
261         match self {
262             FilterStrategy::IgnoreByRef => true,
263             FilterStrategy::IgnoreNone => false,
264         }
265     }
266 }
267 
apply_strategy( strategy: FilterStrategy, by_ref: bool, r: Result<(), MlsError>, ) -> Result<bool, MlsError>268 pub(crate) fn apply_strategy(
269     strategy: FilterStrategy,
270     by_ref: bool,
271     r: Result<(), MlsError>,
272 ) -> Result<bool, MlsError> {
273     r.map(|_| true)
274         .or_else(|error| strategy.ignore(by_ref).then_some(false).ok_or(error))
275 }
276 
filter_out_update_for_committer( strategy: FilterStrategy, commit_sender: LeafIndex, mut proposals: ProposalBundle, ) -> Result<ProposalBundle, MlsError>277 fn filter_out_update_for_committer(
278     strategy: FilterStrategy,
279     commit_sender: LeafIndex,
280     mut proposals: ProposalBundle,
281 ) -> Result<ProposalBundle, MlsError> {
282     proposals.retain_by_type::<UpdateProposal, _, _>(|p| {
283         apply_strategy(
284             strategy,
285             p.is_by_reference(),
286             (p.sender != Sender::Member(*commit_sender))
287                 .then_some(())
288                 .ok_or(MlsError::InvalidCommitSelfUpdate),
289         )
290     })?;
291     Ok(proposals)
292 }
293 
filter_out_removal_of_committer( strategy: FilterStrategy, commit_sender: LeafIndex, mut proposals: ProposalBundle, ) -> Result<ProposalBundle, MlsError>294 fn filter_out_removal_of_committer(
295     strategy: FilterStrategy,
296     commit_sender: LeafIndex,
297     mut proposals: ProposalBundle,
298 ) -> Result<ProposalBundle, MlsError> {
299     proposals.retain_by_type::<RemoveProposal, _, _>(|p| {
300         apply_strategy(
301             strategy,
302             p.is_by_reference(),
303             (p.proposal.to_remove != commit_sender)
304                 .then_some(())
305                 .ok_or(MlsError::CommitterSelfRemoval),
306         )
307     })?;
308     Ok(proposals)
309 }
310 
#[cfg(feature = "by_ref_proposal")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
/// Checks the `ExternalSendersExt` carried by each group context extensions
/// proposal.
///
/// A proposal without that extension passes. A proposal whose extension cannot
/// be decoded fails with the converted decoding error. Otherwise every external
/// sender is verified through `identity_provider` at `commit_time`. Failing
/// proposals are dropped or reported according to `strategy`.
async fn filter_out_invalid_group_extensions<C>(
    strategy: FilterStrategy,
    mut proposals: ProposalBundle,
    identity_provider: &C,
    commit_time: Option<MlsTime>,
) -> Result<ProposalBundle, MlsError>
where
    C: IdentityProvider,
{
    let mut rejected = Vec::new();

    for (index, info) in proposals.by_type::<ExtensionList>().enumerate() {
        let outcome = match info.proposal.get_as::<ExternalSendersExt>() {
            Err(e) => Err(MlsError::from(e)),
            Ok(None) => Ok(()),
            Ok(Some(senders)) => senders
                .verify_all(identity_provider, commit_time, &info.proposal)
                .await
                .map_err(|e| MlsError::IdentityProviderError(e.into_any_error())),
        };

        let keep = apply_strategy(strategy, info.is_by_reference(), outcome)?;

        if !keep {
            rejected.push(index);
        }
    }

    // Remove from the back so that earlier removals do not shift later indices.
    for index in rejected.into_iter().rev() {
        proposals.remove::<ExtensionList>(index);
    }

    Ok(proposals)
}
348 
filter_out_extra_group_context_extensions( strategy: FilterStrategy, mut proposals: ProposalBundle, ) -> Result<ProposalBundle, MlsError>349 fn filter_out_extra_group_context_extensions(
350     strategy: FilterStrategy,
351     mut proposals: ProposalBundle,
352 ) -> Result<ProposalBundle, MlsError> {
353     let mut found = false;
354 
355     proposals.retain_by_type::<ExtensionList, _, _>(|p| {
356         apply_strategy(
357             strategy,
358             p.is_by_reference(),
359             (!core::mem::replace(&mut found, true))
360                 .then_some(())
361                 .ok_or(MlsError::MoreThanOneGroupContextExtensionsProposal),
362         )
363     })?;
364 
365     Ok(proposals)
366 }
367 
filter_out_invalid_reinit( strategy: FilterStrategy, mut proposals: ProposalBundle, protocol_version: ProtocolVersion, ) -> Result<ProposalBundle, MlsError>368 fn filter_out_invalid_reinit(
369     strategy: FilterStrategy,
370     mut proposals: ProposalBundle,
371     protocol_version: ProtocolVersion,
372 ) -> Result<ProposalBundle, MlsError> {
373     proposals.retain_by_type::<ReInitProposal, _, _>(|p| {
374         apply_strategy(
375             strategy,
376             p.is_by_reference(),
377             (p.proposal.version >= protocol_version)
378                 .then_some(())
379                 .ok_or(MlsError::InvalidProtocolVersionInReInit),
380         )
381     })?;
382 
383     Ok(proposals)
384 }
385 
filter_out_reinit_if_other_proposals( filter: bool, mut proposals: ProposalBundle, ) -> Result<ProposalBundle, MlsError>386 fn filter_out_reinit_if_other_proposals(
387     filter: bool,
388     mut proposals: ProposalBundle,
389 ) -> Result<ProposalBundle, MlsError> {
390     let proposal_count = proposals.length();
391 
392     let has_reinit_and_other_proposal =
393         !proposals.reinit_proposals().is_empty() && proposal_count != 1;
394 
395     if has_reinit_and_other_proposal {
396         let any_by_val = proposals.reinit_proposals().iter().any(|p| p.is_by_value());
397 
398         if any_by_val || !filter {
399             return Err(MlsError::OtherProposalWithReInit);
400         }
401 
402         let has_other_proposal_type = proposal_count > proposals.reinit_proposals().len();
403 
404         if has_other_proposal_type {
405             proposals.reinitializations = Vec::new();
406         } else {
407             proposals.reinitializations.truncate(1);
408         }
409     }
410 
411     Ok(proposals)
412 }
413 
filter_out_external_init( strategy: FilterStrategy, mut proposals: ProposalBundle, ) -> Result<ProposalBundle, MlsError>414 fn filter_out_external_init(
415     strategy: FilterStrategy,
416     mut proposals: ProposalBundle,
417 ) -> Result<ProposalBundle, MlsError> {
418     proposals.retain_by_type::<ExternalInit, _, _>(|p| {
419         apply_strategy(
420             strategy,
421             p.is_by_reference(),
422             Err(MlsError::InvalidProposalTypeForSender),
423         )
424     })?;
425 
426     Ok(proposals)
427 }
428 
proposer_can_propose( proposer: Sender, proposal_type: ProposalType, by_ref: bool, ) -> Result<(), MlsError>429 pub(crate) fn proposer_can_propose(
430     proposer: Sender,
431     proposal_type: ProposalType,
432     by_ref: bool,
433 ) -> Result<(), MlsError> {
434     let can_propose = match (proposer, by_ref) {
435         (Sender::Member(_), false) => matches!(
436             proposal_type,
437             ProposalType::ADD
438                 | ProposalType::REMOVE
439                 | ProposalType::PSK
440                 | ProposalType::RE_INIT
441                 | ProposalType::GROUP_CONTEXT_EXTENSIONS
442         ),
443         (Sender::Member(_), true) => matches!(
444             proposal_type,
445             ProposalType::ADD
446                 | ProposalType::UPDATE
447                 | ProposalType::REMOVE
448                 | ProposalType::PSK
449                 | ProposalType::RE_INIT
450                 | ProposalType::GROUP_CONTEXT_EXTENSIONS
451         ),
452         #[cfg(feature = "by_ref_proposal")]
453         (Sender::External(_), false) => false,
454         #[cfg(feature = "by_ref_proposal")]
455         (Sender::External(_), true) => matches!(
456             proposal_type,
457             ProposalType::ADD
458                 | ProposalType::REMOVE
459                 | ProposalType::RE_INIT
460                 | ProposalType::PSK
461                 | ProposalType::GROUP_CONTEXT_EXTENSIONS
462         ),
463         (Sender::NewMemberCommit, false) => matches!(
464             proposal_type,
465             ProposalType::REMOVE | ProposalType::PSK | ProposalType::EXTERNAL_INIT
466         ),
467         (Sender::NewMemberCommit, true) => false,
468         (Sender::NewMemberProposal, false) => false,
469         (Sender::NewMemberProposal, true) => matches!(proposal_type, ProposalType::ADD),
470     };
471 
472     can_propose
473         .then_some(())
474         .ok_or(MlsError::InvalidProposalTypeForSender)
475 }
476 
filter_out_invalid_proposers( strategy: FilterStrategy, mut proposals: ProposalBundle, ) -> Result<ProposalBundle, MlsError>477 pub(crate) fn filter_out_invalid_proposers(
478     strategy: FilterStrategy,
479     mut proposals: ProposalBundle,
480 ) -> Result<ProposalBundle, MlsError> {
481     for i in (0..proposals.add_proposals().len()).rev() {
482         let p = &proposals.add_proposals()[i];
483         let res = proposer_can_propose(p.sender, ProposalType::ADD, p.is_by_reference());
484 
485         if !apply_strategy(strategy, p.is_by_reference(), res)? {
486             proposals.remove::<AddProposal>(i);
487         }
488     }
489 
490     for i in (0..proposals.update_proposals().len()).rev() {
491         let p = &proposals.update_proposals()[i];
492         let res = proposer_can_propose(p.sender, ProposalType::UPDATE, p.is_by_reference());
493 
494         if !apply_strategy(strategy, p.is_by_reference(), res)? {
495             proposals.remove::<UpdateProposal>(i);
496             proposals.update_senders.remove(i);
497         }
498     }
499 
500     for i in (0..proposals.remove_proposals().len()).rev() {
501         let p = &proposals.remove_proposals()[i];
502         let res = proposer_can_propose(p.sender, ProposalType::REMOVE, p.is_by_reference());
503 
504         if !apply_strategy(strategy, p.is_by_reference(), res)? {
505             proposals.remove::<RemoveProposal>(i);
506         }
507     }
508 
509     #[cfg(feature = "psk")]
510     for i in (0..proposals.psk_proposals().len()).rev() {
511         let p = &proposals.psk_proposals()[i];
512         let res = proposer_can_propose(p.sender, ProposalType::PSK, p.is_by_reference());
513 
514         if !apply_strategy(strategy, p.is_by_reference(), res)? {
515             proposals.remove::<PreSharedKeyProposal>(i);
516         }
517     }
518 
519     for i in (0..proposals.reinit_proposals().len()).rev() {
520         let p = &proposals.reinit_proposals()[i];
521         let res = proposer_can_propose(p.sender, ProposalType::RE_INIT, p.is_by_reference());
522 
523         if !apply_strategy(strategy, p.is_by_reference(), res)? {
524             proposals.remove::<ReInitProposal>(i);
525         }
526     }
527 
528     for i in (0..proposals.external_init_proposals().len()).rev() {
529         let p = &proposals.external_init_proposals()[i];
530         let res = proposer_can_propose(p.sender, ProposalType::EXTERNAL_INIT, p.is_by_reference());
531 
532         if !apply_strategy(strategy, p.is_by_reference(), res)? {
533             proposals.remove::<ExternalInit>(i);
534         }
535     }
536 
537     for i in (0..proposals.group_context_ext_proposals().len()).rev() {
538         let p = &proposals.group_context_ext_proposals()[i];
539         let gce_type = ProposalType::GROUP_CONTEXT_EXTENSIONS;
540         let res = proposer_can_propose(p.sender, gce_type, p.is_by_reference());
541 
542         if !apply_strategy(strategy, p.is_by_reference(), res)? {
543             proposals.remove::<ExtensionList>(i);
544         }
545     }
546 
547     Ok(proposals)
548 }
549 
leaf_index_of_update_sender(p: &ProposalInfo<UpdateProposal>) -> Result<LeafIndex, MlsError>550 fn leaf_index_of_update_sender(p: &ProposalInfo<UpdateProposal>) -> Result<LeafIndex, MlsError> {
551     match p.sender {
552         Sender::Member(i) => Ok(LeafIndex(i)),
553         _ => Err(MlsError::InvalidProposalTypeForSender),
554     }
555 }
556 
#[cfg(feature = "custom_proposal")]
/// Rejects custom proposals whose type `tree` reports it cannot support.
///
/// The supported custom types are computed once up front. Each custom proposal
/// of any other type fails with `MlsError::UnsupportedCustomProposal` and is
/// dropped or reported according to `strategy`.
pub(super) fn filter_out_unsupported_custom_proposals(
    proposals: &mut ProposalBundle,
    tree: &TreeKemPublic,
    strategy: FilterStrategy,
) -> Result<(), MlsError> {
    let supported = proposals
        .custom_proposal_types()
        .filter(|candidate| tree.can_support_proposal(*candidate))
        .collect_vec();

    proposals.retain_custom(|info| {
        let kind = info.proposal.proposal_type();

        let outcome = if supported.contains(&kind) {
            Ok(())
        } else {
            Err(MlsError::UnsupportedCustomProposal(kind))
        };

        apply_strategy(strategy, info.is_by_reference(), outcome)
    })
}
581