1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5 use crate::{
6 client::MlsError,
7 group::{
8 proposal::ReInitProposal,
9 proposal_filter::{ProposalBundle, ProposalInfo},
10 AddProposal, ProposalType, RemoveProposal, Sender, UpdateProposal,
11 },
12 iter::wrap_iter,
13 protocol_version::ProtocolVersion,
14 time::MlsTime,
15 tree_kem::{
16 leaf_node_validator::{LeafNodeValidator, ValidationContext},
17 node::LeafIndex,
18 TreeKemPublic,
19 },
20 CipherSuiteProvider, ExtensionList,
21 };
22
23 use super::filtering_common::{filter_out_invalid_psks, ApplyProposalsOutput, ProposalApplier};
24
25 #[cfg(feature = "by_ref_proposal")]
26 use crate::extension::ExternalSendersExt;
27
28 use alloc::vec::Vec;
29 use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider, psk::PreSharedKeyStorage};
30
31 #[cfg(any(
32 feature = "custom_proposal",
33 not(any(mls_build_async, feature = "rayon"))
34 ))]
35 use itertools::Itertools;
36
37 use crate::group::ExternalInit;
38
39 #[cfg(feature = "psk")]
40 use crate::group::proposal::PreSharedKeyProposal;
41
42 #[cfg(all(not(mls_build_async), feature = "rayon"))]
43 use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
44
45 #[cfg(mls_build_async)]
46 use futures::{StreamExt, TryStreamExt};
47
impl<'a, C, P, CSP> ProposalApplier<'a, C, P, CSP>
where
    C: IdentityProvider,
    P: PreSharedKeyStorage,
    CSP: CipherSuiteProvider,
{
    /// Filters a proposal bundle committed by the existing member at `commit_sender`, then
    /// applies the surviving proposals.
    ///
    /// Filters run in this fixed order. Each one consults `strategy` to decide whether an
    /// offending proposal is dropped or the whole call fails with its error:
    ///
    /// 1. proposals whose sender may not send that proposal type (`filter_out_invalid_proposers`);
    /// 2. update proposals sent by the committer itself;
    /// 3. remove proposals that target the committer;
    /// 4. invalid pre-shared-key proposals (`filter_out_invalid_psks`);
    /// 5. with the `by_ref_proposal` feature, group-context-extensions proposals whose
    ///    external-senders extension fails to decode or verify;
    /// 6. every group-context-extensions proposal after the first;
    /// 7. re-init proposals whose version is lower than the group's protocol version;
    /// 8. re-init proposals combined with other proposals;
    /// 9. all external-init proposals.
    ///
    /// # Errors
    ///
    /// Returns the first failure that `strategy` does not allow to be ignored, or any error
    /// raised by `apply_proposal_changes`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(super) async fn apply_proposals_from_member(
        &self,
        strategy: FilterStrategy,
        commit_sender: LeafIndex,
        proposals: ProposalBundle,
        commit_time: Option<MlsTime>,
    ) -> Result<ApplyProposalsOutput, MlsError> {
        let proposals = filter_out_invalid_proposers(strategy, proposals)?;

        let mut proposals: ProposalBundle =
            filter_out_update_for_committer(strategy, commit_sender, proposals)?;

        // We ignore the strategy here because the check above ensures all updates are from members
        // `update_senders` is populated here, one entry per remaining update and in the same
        // order; `validate_new_nodes` later zips the two lists together.
        proposals.update_senders = proposals
            .updates
            .iter()
            .map(leaf_index_of_update_sender)
            .collect::<Result<_, _>>()?;

        let mut proposals = filter_out_removal_of_committer(strategy, commit_sender, proposals)?;

        // Drops or rejects PSK proposals in place (the bundle is borrowed mutably).
        filter_out_invalid_psks(
            strategy,
            self.cipher_suite_provider,
            &mut proposals,
            self.psk_storage,
        )
        .await?;

        #[cfg(feature = "by_ref_proposal")]
        let proposals = filter_out_invalid_group_extensions(
            strategy,
            proposals,
            self.identity_provider,
            commit_time,
        )
        .await?;

        let proposals = filter_out_extra_group_context_extensions(strategy, proposals)?;
        let proposals = filter_out_invalid_reinit(strategy, proposals, self.protocol_version)?;
        // This filter takes a plain flag: any strategy that ignores anything enables it.
        let proposals = filter_out_reinit_if_other_proposals(strategy.is_ignore(), proposals)?;

        // A member's commit may never carry external-init proposals.
        let proposals = filter_out_external_init(strategy, proposals)?;

        self.apply_proposal_changes(strategy, proposals, commit_time)
            .await
    }

    /// Applies an already-filtered bundle, choosing which group extensions govern validation.
    ///
    /// If the bundle contains a group-context-extensions proposal, the work is delegated to
    /// `apply_proposals_with_new_capabilities` together with that proposal. Otherwise the
    /// tree changes are applied against the group's original extensions.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(super) async fn apply_proposal_changes(
        &self,
        strategy: FilterStrategy,
        proposals: ProposalBundle,
        commit_time: Option<MlsTime>,
    ) -> Result<ApplyProposalsOutput, MlsError> {
        // Cloned so that `proposals` itself can be moved into the callee below.
        match proposals.group_context_extensions_proposal().cloned() {
            Some(p) => {
                self.apply_proposals_with_new_capabilities(strategy, proposals, p, commit_time)
                    .await
            }
            None => {
                self.apply_tree_changes(
                    strategy,
                    proposals,
                    self.original_group_extensions,
                    commit_time,
                )
                .await
            }
        }
    }

    /// Validates the new leaf nodes in `proposals`, then applies the bundle's tree edits to a
    /// clone of the original tree.
    ///
    /// New leaves are validated against `group_extensions_in_use`. The original tree is not
    /// modified; the edited copy is returned as `new_tree`, together with:
    /// - the proposals that survived validation;
    /// - the indexes `batch_edit` reports for added key packages;
    /// - the group-context-extensions proposal's extension list, if one is present.
    ///
    /// `external_init_index` is always `None` on this path.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(super) async fn apply_tree_changes(
        &self,
        strategy: FilterStrategy,
        proposals: ProposalBundle,
        group_extensions_in_use: &ExtensionList,
        commit_time: Option<MlsTime>,
    ) -> Result<ApplyProposalsOutput, MlsError> {
        let mut applied_proposals = self
            .validate_new_nodes(strategy, proposals, group_extensions_in_use, commit_time)
            .await?;

        let mut new_tree = self.original_tree.clone();

        // NOTE(review): the last argument presumably lets `batch_edit` drop invalid edits
        // instead of failing when the strategy ignores by-reference proposals — confirm in
        // `TreeKemPublic::batch_edit`.
        let added = new_tree
            .batch_edit(
                &mut applied_proposals,
                group_extensions_in_use,
                self.identity_provider,
                self.cipher_suite_provider,
                strategy.is_ignore(),
            )
            .await?;

        let new_context_extensions = applied_proposals
            .group_context_extensions_proposal()
            .map(|gce| gce.proposal.clone());

        Ok(ApplyProposalsOutput {
            applied_proposals,
            new_tree,
            indexes_of_added_kpkgs: added,
            external_init_index: None,
            new_context_extensions,
        })
    }

    /// Validates the leaf nodes introduced by update and add proposals. Invalid ones are
    /// dropped or rejected according to `strategy`.
    ///
    /// For each update:
    /// - the new leaf must pass `LeafNodeValidator::check_if_valid` in an update context for
    ///   its sender;
    /// - the identity provider must accept the new signing identity as a valid successor of
    ///   the sender's current leaf.
    ///
    /// For each add, the key package is checked with `validate_new_node`.
    ///
    /// Failing to look up an update sender's current leaf is always a hard error, regardless
    /// of `strategy`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn validate_new_nodes(
        &self,
        strategy: FilterStrategy,
        mut proposals: ProposalBundle,
        group_extensions_in_use: &ExtensionList,
        commit_time: Option<MlsTime>,
    ) -> Result<ProposalBundle, MlsError> {
        let leaf_node_validator = &LeafNodeValidator::new(
            self.cipher_suite_provider,
            self.identity_provider,
            Some(group_extensions_in_use),
        );

        // NOTE(review): relies on `update_senders` being parallel to `updates`; `zip` stops
        // silently at the shorter of the two.
        // Each closure yields `Some(Ok(i))` for a proposal to drop, `Some(Err(_))` to abort,
        // and `None` to keep the proposal.
        let bad_indices: Vec<_> = wrap_iter(proposals.update_proposals())
            .zip(wrap_iter(proposals.update_proposal_senders()))
            .enumerate()
            .filter_map(|(i, (p, &sender_index))| async move {
                let res = {
                    let leaf = &p.proposal.leaf_node;

                    let res = leaf_node_validator
                        .check_if_valid(
                            leaf,
                            ValidationContext::Update((self.group_id, *sender_index, commit_time)),
                        )
                        .await;

                    // A missing current leaf bypasses the strategy and aborts immediately.
                    let old_leaf = match self.original_tree.get_leaf_node(sender_index) {
                        Ok(leaf) => leaf,
                        Err(e) => return Some(Err(e)),
                    };

                    let valid_successor = self
                        .identity_provider
                        .valid_successor(
                            &old_leaf.signing_identity,
                            &leaf.signing_identity,
                            group_extensions_in_use,
                        )
                        .await
                        .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))
                        .and_then(|valid| valid.then_some(()).ok_or(MlsError::InvalidSuccessor));

                    // If both checks fail, the leaf validator's error is the one reported.
                    res.and(valid_successor)
                };

                apply_strategy(strategy, p.is_by_reference(), res)
                    .map(|b| (!b).then_some(i))
                    .transpose()
            })
            .try_collect()
            .await?;

        // Remove from the highest index down so the remaining indices stay valid; the sender
        // list is kept parallel to the updates.
        bad_indices.into_iter().rev().for_each(|i| {
            proposals.remove::<UpdateProposal>(i);
            proposals.update_senders.remove(i);
        });

        let bad_indices: Vec<_> = wrap_iter(proposals.add_proposals())
            .enumerate()
            .filter_map(|(i, p)| async move {
                let res = self
                    .validate_new_node(leaf_node_validator, &p.proposal.key_package, commit_time)
                    .await;

                apply_strategy(strategy, p.is_by_reference(), res)
                    .map(|b| (!b).then_some(i))
                    .transpose()
            })
            .try_collect()
            .await?;

        bad_indices
            .into_iter()
            .rev()
            .for_each(|i| proposals.remove::<AddProposal>(i));

        Ok(proposals)
    }
}
245
246 #[derive(Clone, Copy, Debug)]
247 pub enum FilterStrategy {
248 IgnoreByRef,
249 IgnoreNone,
250 }
251
252 impl FilterStrategy {
ignore(self, by_ref: bool) -> bool253 pub(super) fn ignore(self, by_ref: bool) -> bool {
254 match self {
255 FilterStrategy::IgnoreByRef => by_ref,
256 FilterStrategy::IgnoreNone => false,
257 }
258 }
259
is_ignore(self) -> bool260 fn is_ignore(self) -> bool {
261 match self {
262 FilterStrategy::IgnoreByRef => true,
263 FilterStrategy::IgnoreNone => false,
264 }
265 }
266 }
267
apply_strategy( strategy: FilterStrategy, by_ref: bool, r: Result<(), MlsError>, ) -> Result<bool, MlsError>268 pub(crate) fn apply_strategy(
269 strategy: FilterStrategy,
270 by_ref: bool,
271 r: Result<(), MlsError>,
272 ) -> Result<bool, MlsError> {
273 r.map(|_| true)
274 .or_else(|error| strategy.ignore(by_ref).then_some(false).ok_or(error))
275 }
276
filter_out_update_for_committer( strategy: FilterStrategy, commit_sender: LeafIndex, mut proposals: ProposalBundle, ) -> Result<ProposalBundle, MlsError>277 fn filter_out_update_for_committer(
278 strategy: FilterStrategy,
279 commit_sender: LeafIndex,
280 mut proposals: ProposalBundle,
281 ) -> Result<ProposalBundle, MlsError> {
282 proposals.retain_by_type::<UpdateProposal, _, _>(|p| {
283 apply_strategy(
284 strategy,
285 p.is_by_reference(),
286 (p.sender != Sender::Member(*commit_sender))
287 .then_some(())
288 .ok_or(MlsError::InvalidCommitSelfUpdate),
289 )
290 })?;
291 Ok(proposals)
292 }
293
filter_out_removal_of_committer( strategy: FilterStrategy, commit_sender: LeafIndex, mut proposals: ProposalBundle, ) -> Result<ProposalBundle, MlsError>294 fn filter_out_removal_of_committer(
295 strategy: FilterStrategy,
296 commit_sender: LeafIndex,
297 mut proposals: ProposalBundle,
298 ) -> Result<ProposalBundle, MlsError> {
299 proposals.retain_by_type::<RemoveProposal, _, _>(|p| {
300 apply_strategy(
301 strategy,
302 p.is_by_reference(),
303 (p.proposal.to_remove != commit_sender)
304 .then_some(())
305 .ok_or(MlsError::CommitterSelfRemoval),
306 )
307 })?;
308 Ok(proposals)
309 }
310
/// Drops or rejects group-context-extensions proposals whose external-senders extension is
/// invalid.
///
/// For each proposal, the `ExternalSendersExt` extension, if present, must decode and pass
/// `verify_all` against `identity_provider` at `commit_time`. Proposals without that
/// extension always pass. Failures are routed through `apply_strategy`.
#[cfg(feature = "by_ref_proposal")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn filter_out_invalid_group_extensions<C>(
    strategy: FilterStrategy,
    mut proposals: ProposalBundle,
    identity_provider: &C,
    commit_time: Option<MlsTime>,
) -> Result<ProposalBundle, MlsError>
where
    C: IdentityProvider,
{
    let mut rejected = Vec::new();

    for (index, info) in proposals.by_type::<ExtensionList>().enumerate() {
        let senders_ext = info.proposal.get_as::<ExternalSendersExt>();

        let outcome = match senders_ext {
            Err(e) => Err(MlsError::from(e)),
            Ok(None) => Ok(()),
            Ok(Some(external_senders)) => external_senders
                .verify_all(identity_provider, commit_time, &info.proposal)
                .await
                .map_err(|e| MlsError::IdentityProviderError(e.into_any_error())),
        };

        let keep = apply_strategy(strategy, info.is_by_reference(), outcome)?;

        if !keep {
            rejected.push(index);
        }
    }

    // Remove from the back so the indices still to be removed stay valid.
    for index in rejected.into_iter().rev() {
        proposals.remove::<ExtensionList>(index);
    }

    Ok(proposals)
}
348
filter_out_extra_group_context_extensions( strategy: FilterStrategy, mut proposals: ProposalBundle, ) -> Result<ProposalBundle, MlsError>349 fn filter_out_extra_group_context_extensions(
350 strategy: FilterStrategy,
351 mut proposals: ProposalBundle,
352 ) -> Result<ProposalBundle, MlsError> {
353 let mut found = false;
354
355 proposals.retain_by_type::<ExtensionList, _, _>(|p| {
356 apply_strategy(
357 strategy,
358 p.is_by_reference(),
359 (!core::mem::replace(&mut found, true))
360 .then_some(())
361 .ok_or(MlsError::MoreThanOneGroupContextExtensionsProposal),
362 )
363 })?;
364
365 Ok(proposals)
366 }
367
filter_out_invalid_reinit( strategy: FilterStrategy, mut proposals: ProposalBundle, protocol_version: ProtocolVersion, ) -> Result<ProposalBundle, MlsError>368 fn filter_out_invalid_reinit(
369 strategy: FilterStrategy,
370 mut proposals: ProposalBundle,
371 protocol_version: ProtocolVersion,
372 ) -> Result<ProposalBundle, MlsError> {
373 proposals.retain_by_type::<ReInitProposal, _, _>(|p| {
374 apply_strategy(
375 strategy,
376 p.is_by_reference(),
377 (p.proposal.version >= protocol_version)
378 .then_some(())
379 .ok_or(MlsError::InvalidProtocolVersionInReInit),
380 )
381 })?;
382
383 Ok(proposals)
384 }
385
filter_out_reinit_if_other_proposals( filter: bool, mut proposals: ProposalBundle, ) -> Result<ProposalBundle, MlsError>386 fn filter_out_reinit_if_other_proposals(
387 filter: bool,
388 mut proposals: ProposalBundle,
389 ) -> Result<ProposalBundle, MlsError> {
390 let proposal_count = proposals.length();
391
392 let has_reinit_and_other_proposal =
393 !proposals.reinit_proposals().is_empty() && proposal_count != 1;
394
395 if has_reinit_and_other_proposal {
396 let any_by_val = proposals.reinit_proposals().iter().any(|p| p.is_by_value());
397
398 if any_by_val || !filter {
399 return Err(MlsError::OtherProposalWithReInit);
400 }
401
402 let has_other_proposal_type = proposal_count > proposals.reinit_proposals().len();
403
404 if has_other_proposal_type {
405 proposals.reinitializations = Vec::new();
406 } else {
407 proposals.reinitializations.truncate(1);
408 }
409 }
410
411 Ok(proposals)
412 }
413
filter_out_external_init( strategy: FilterStrategy, mut proposals: ProposalBundle, ) -> Result<ProposalBundle, MlsError>414 fn filter_out_external_init(
415 strategy: FilterStrategy,
416 mut proposals: ProposalBundle,
417 ) -> Result<ProposalBundle, MlsError> {
418 proposals.retain_by_type::<ExternalInit, _, _>(|p| {
419 apply_strategy(
420 strategy,
421 p.is_by_reference(),
422 Err(MlsError::InvalidProposalTypeForSender),
423 )
424 })?;
425
426 Ok(proposals)
427 }
428
proposer_can_propose( proposer: Sender, proposal_type: ProposalType, by_ref: bool, ) -> Result<(), MlsError>429 pub(crate) fn proposer_can_propose(
430 proposer: Sender,
431 proposal_type: ProposalType,
432 by_ref: bool,
433 ) -> Result<(), MlsError> {
434 let can_propose = match (proposer, by_ref) {
435 (Sender::Member(_), false) => matches!(
436 proposal_type,
437 ProposalType::ADD
438 | ProposalType::REMOVE
439 | ProposalType::PSK
440 | ProposalType::RE_INIT
441 | ProposalType::GROUP_CONTEXT_EXTENSIONS
442 ),
443 (Sender::Member(_), true) => matches!(
444 proposal_type,
445 ProposalType::ADD
446 | ProposalType::UPDATE
447 | ProposalType::REMOVE
448 | ProposalType::PSK
449 | ProposalType::RE_INIT
450 | ProposalType::GROUP_CONTEXT_EXTENSIONS
451 ),
452 #[cfg(feature = "by_ref_proposal")]
453 (Sender::External(_), false) => false,
454 #[cfg(feature = "by_ref_proposal")]
455 (Sender::External(_), true) => matches!(
456 proposal_type,
457 ProposalType::ADD
458 | ProposalType::REMOVE
459 | ProposalType::RE_INIT
460 | ProposalType::PSK
461 | ProposalType::GROUP_CONTEXT_EXTENSIONS
462 ),
463 (Sender::NewMemberCommit, false) => matches!(
464 proposal_type,
465 ProposalType::REMOVE | ProposalType::PSK | ProposalType::EXTERNAL_INIT
466 ),
467 (Sender::NewMemberCommit, true) => false,
468 (Sender::NewMemberProposal, false) => false,
469 (Sender::NewMemberProposal, true) => matches!(proposal_type, ProposalType::ADD),
470 };
471
472 can_propose
473 .then_some(())
474 .ok_or(MlsError::InvalidProposalTypeForSender)
475 }
476
filter_out_invalid_proposers( strategy: FilterStrategy, mut proposals: ProposalBundle, ) -> Result<ProposalBundle, MlsError>477 pub(crate) fn filter_out_invalid_proposers(
478 strategy: FilterStrategy,
479 mut proposals: ProposalBundle,
480 ) -> Result<ProposalBundle, MlsError> {
481 for i in (0..proposals.add_proposals().len()).rev() {
482 let p = &proposals.add_proposals()[i];
483 let res = proposer_can_propose(p.sender, ProposalType::ADD, p.is_by_reference());
484
485 if !apply_strategy(strategy, p.is_by_reference(), res)? {
486 proposals.remove::<AddProposal>(i);
487 }
488 }
489
490 for i in (0..proposals.update_proposals().len()).rev() {
491 let p = &proposals.update_proposals()[i];
492 let res = proposer_can_propose(p.sender, ProposalType::UPDATE, p.is_by_reference());
493
494 if !apply_strategy(strategy, p.is_by_reference(), res)? {
495 proposals.remove::<UpdateProposal>(i);
496 proposals.update_senders.remove(i);
497 }
498 }
499
500 for i in (0..proposals.remove_proposals().len()).rev() {
501 let p = &proposals.remove_proposals()[i];
502 let res = proposer_can_propose(p.sender, ProposalType::REMOVE, p.is_by_reference());
503
504 if !apply_strategy(strategy, p.is_by_reference(), res)? {
505 proposals.remove::<RemoveProposal>(i);
506 }
507 }
508
509 #[cfg(feature = "psk")]
510 for i in (0..proposals.psk_proposals().len()).rev() {
511 let p = &proposals.psk_proposals()[i];
512 let res = proposer_can_propose(p.sender, ProposalType::PSK, p.is_by_reference());
513
514 if !apply_strategy(strategy, p.is_by_reference(), res)? {
515 proposals.remove::<PreSharedKeyProposal>(i);
516 }
517 }
518
519 for i in (0..proposals.reinit_proposals().len()).rev() {
520 let p = &proposals.reinit_proposals()[i];
521 let res = proposer_can_propose(p.sender, ProposalType::RE_INIT, p.is_by_reference());
522
523 if !apply_strategy(strategy, p.is_by_reference(), res)? {
524 proposals.remove::<ReInitProposal>(i);
525 }
526 }
527
528 for i in (0..proposals.external_init_proposals().len()).rev() {
529 let p = &proposals.external_init_proposals()[i];
530 let res = proposer_can_propose(p.sender, ProposalType::EXTERNAL_INIT, p.is_by_reference());
531
532 if !apply_strategy(strategy, p.is_by_reference(), res)? {
533 proposals.remove::<ExternalInit>(i);
534 }
535 }
536
537 for i in (0..proposals.group_context_ext_proposals().len()).rev() {
538 let p = &proposals.group_context_ext_proposals()[i];
539 let gce_type = ProposalType::GROUP_CONTEXT_EXTENSIONS;
540 let res = proposer_can_propose(p.sender, gce_type, p.is_by_reference());
541
542 if !apply_strategy(strategy, p.is_by_reference(), res)? {
543 proposals.remove::<ExtensionList>(i);
544 }
545 }
546
547 Ok(proposals)
548 }
549
leaf_index_of_update_sender(p: &ProposalInfo<UpdateProposal>) -> Result<LeafIndex, MlsError>550 fn leaf_index_of_update_sender(p: &ProposalInfo<UpdateProposal>) -> Result<LeafIndex, MlsError> {
551 match p.sender {
552 Sender::Member(i) => Ok(LeafIndex(i)),
553 _ => Err(MlsError::InvalidProposalTypeForSender),
554 }
555 }
556
/// Drops or rejects custom proposals whose type `tree` cannot support.
///
/// Support is decided with `TreeKemPublic::can_support_proposal` before the bundle is
/// modified. An unsupported proposal fails with `MlsError::UnsupportedCustomProposal`,
/// which `strategy` may turn into a silent drop.
#[cfg(feature = "custom_proposal")]
pub(super) fn filter_out_unsupported_custom_proposals(
    proposals: &mut ProposalBundle,
    tree: &TreeKemPublic,
    strategy: FilterStrategy,
) -> Result<(), MlsError> {
    let supported = proposals
        .custom_proposal_types()
        .filter(|&candidate| tree.can_support_proposal(candidate))
        .collect_vec();

    proposals.retain_custom(|custom| {
        let proposal_type = custom.proposal.proposal_type();

        let check = if supported.contains(&proposal_type) {
            Ok(())
        } else {
            Err(MlsError::UnsupportedCustomProposal(proposal_type))
        };

        apply_strategy(strategy, custom.is_by_reference(), check)
    })
}
581