1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5 use alloc::vec;
6 use alloc::vec::Vec;
7 #[cfg(feature = "std")]
8 use core::fmt::Display;
9 use itertools::Itertools;
10 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
11 use mls_rs_core::extension::ExtensionList;
12
13 use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider};
14
15 #[cfg(feature = "tree_index")]
16 use mls_rs_core::identity::SigningIdentity;
17
18 use math as tree_math;
19 use node::{LeafIndex, NodeIndex, NodeVec};
20
21 use self::leaf_node::LeafNode;
22
23 use crate::client::MlsError;
24 use crate::crypto::{self, CipherSuiteProvider, HpkeSecretKey};
25
26 #[cfg(feature = "by_ref_proposal")]
27 use crate::group::proposal::{AddProposal, UpdateProposal};
28
29 #[cfg(any(test, feature = "by_ref_proposal"))]
30 use crate::group::proposal::RemoveProposal;
31
32 use crate::group::proposal_filter::ProposalBundle;
33 use crate::tree_kem::tree_hash::TreeHashes;
34
35 mod capabilities;
36 pub(crate) mod hpke_encryption;
37 mod lifetime;
38 pub(crate) mod math;
39 pub mod node;
40 pub mod parent_hash;
41 pub mod path_secret;
42 mod private;
43 mod tree_hash;
44 pub mod tree_validator;
45 pub mod update_path;
46
47 pub use capabilities::*;
48 pub use lifetime::*;
49 pub(crate) use private::*;
50 pub use update_path::*;
51
52 use tree_index::*;
53
54 pub mod kem;
55 pub mod leaf_node;
56 pub mod leaf_node_validator;
57 mod tree_index;
58
59 #[cfg(feature = "std")]
60 pub(crate) mod tree_utils;
61
62 #[cfg(test)]
63 mod interop_test_vectors;
64
65 #[cfg(feature = "custom_proposal")]
66 use crate::group::proposal::ProposalType;
67
/// Public portion of the tree as stored by this module.
///
/// Note that the `PartialEq` implementation for this type compares `nodes`
/// only.
#[derive(Clone, Debug, MlsEncode, MlsDecode, MlsSize, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TreeKemPublic {
    /// Leaf index, present only with the `tree_index` feature. It is skipped
    /// by serde and populated from `nodes` by `initialize_index_if_necessary`.
    #[cfg(feature = "tree_index")]
    #[cfg_attr(feature = "serde", serde(skip))]
    index: TreeIndex,
    /// The node vector holding the leaves and parent nodes of the tree.
    pub(crate) nodes: NodeVec,
    /// Hash state associated with the tree (see `TreeHashes`).
    tree_hashes: TreeHashes,
}
77
78 impl PartialEq for TreeKemPublic {
eq(&self, other: &Self) -> bool79 fn eq(&self, other: &Self) -> bool {
80 self.nodes == other.nodes
81 }
82 }
83
84 impl TreeKemPublic {
new() -> TreeKemPublic85 pub fn new() -> TreeKemPublic {
86 Default::default()
87 }
88
89 #[cfg_attr(not(feature = "tree_index"), allow(unused))]
90 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
import_node_data<IP>( nodes: NodeVec, identity_provider: &IP, extensions: &ExtensionList, ) -> Result<TreeKemPublic, MlsError> where IP: IdentityProvider,91 pub(crate) async fn import_node_data<IP>(
92 nodes: NodeVec,
93 identity_provider: &IP,
94 extensions: &ExtensionList,
95 ) -> Result<TreeKemPublic, MlsError>
96 where
97 IP: IdentityProvider,
98 {
99 let mut tree = TreeKemPublic {
100 nodes,
101 ..Default::default()
102 };
103
104 #[cfg(feature = "tree_index")]
105 tree.initialize_index_if_necessary(identity_provider, extensions)
106 .await?;
107
108 Ok(tree)
109 }
110
/// Builds the tree index from the non-empty leaves, unless the index already
/// reports itself as initialized.
///
/// # Errors
/// Propagates the first error returned by `index_insert`.
#[cfg(feature = "tree_index")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn initialize_index_if_necessary<IP: IdentityProvider>(
    &mut self,
    identity_provider: &IP,
    extensions: &ExtensionList,
) -> Result<(), MlsError> {
    // Nothing to do when a previous call (or construction) already filled it.
    if self.index.is_initialized() {
        return Ok(());
    }

    self.index = TreeIndex::new();

    for (position, leaf_node) in self.nodes.non_empty_leaves() {
        index_insert(
            &mut self.index,
            leaf_node,
            position,
            identity_provider,
            extensions,
        )
        .await?;
    }

    Ok(())
}
135
/// Looks up `identity` in the tree index and returns the matching leaf
/// index, if the index knows one.
#[cfg(feature = "tree_index")]
pub(crate) fn get_leaf_node_with_identity(&self, identity: &[u8]) -> Option<LeafIndex> {
    let Self { index, .. } = self;
    index.get_leaf_index_with_identity(identity)
}
140
141 #[cfg(not(feature = "tree_index"))]
142 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_leaf_node_with_identity<I: IdentityProvider>( &self, identity: &[u8], id_provider: &I, extensions: &ExtensionList, ) -> Result<Option<LeafIndex>, MlsError>143 pub(crate) async fn get_leaf_node_with_identity<I: IdentityProvider>(
144 &self,
145 identity: &[u8],
146 id_provider: &I,
147 extensions: &ExtensionList,
148 ) -> Result<Option<LeafIndex>, MlsError> {
149 for (i, leaf) in self.nodes.non_empty_leaves() {
150 let leaf_id = id_provider
151 .identity(&leaf.signing_identity, extensions)
152 .await
153 .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
154
155 if leaf_id == identity {
156 return Ok(Some(i));
157 }
158 }
159
160 Ok(None)
161 }
162
163 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
derive<I: IdentityProvider>( leaf_node: LeafNode, secret_key: HpkeSecretKey, identity_provider: &I, extensions: &ExtensionList, ) -> Result<(TreeKemPublic, TreeKemPrivate), MlsError>164 pub async fn derive<I: IdentityProvider>(
165 leaf_node: LeafNode,
166 secret_key: HpkeSecretKey,
167 identity_provider: &I,
168 extensions: &ExtensionList,
169 ) -> Result<(TreeKemPublic, TreeKemPrivate), MlsError> {
170 let mut public_tree = TreeKemPublic::new();
171
172 public_tree
173 .add_leaf(leaf_node, identity_provider, extensions, None)
174 .await?;
175
176 let private_tree = TreeKemPrivate::new_self_leaf(LeafIndex(0), secret_key);
177
178 Ok((public_tree, private_tree))
179 }
180
total_leaf_count(&self) -> u32181 pub fn total_leaf_count(&self) -> u32 {
182 self.nodes.total_leaf_count()
183 }
184
/// Returns the occupied leaf count reported by the node vector.
#[cfg(any(test, all(feature = "custom_proposal", feature = "tree_index")))]
pub fn occupied_leaf_count(&self) -> u32 {
    let Self { nodes, .. } = self;
    nodes.occupied_leaf_count()
}
189
get_leaf_node(&self, index: LeafIndex) -> Result<&LeafNode, MlsError>190 pub fn get_leaf_node(&self, index: LeafIndex) -> Result<&LeafNode, MlsError> {
191 self.nodes.borrow_as_leaf(index)
192 }
193
find_leaf_node(&self, leaf_node: &LeafNode) -> Option<LeafIndex>194 pub fn find_leaf_node(&self, leaf_node: &LeafNode) -> Option<LeafIndex> {
195 self.nodes.non_empty_leaves().find_map(
196 |(index, node)| {
197 if node == leaf_node {
198 Some(index)
199 } else {
200 None
201 }
202 },
203 )
204 }
205
/// Reports whether every occupied leaf supports `proposal_type`.
///
/// With `tree_index`, the index's count of supporting leaves is compared
/// with the occupied leaf count; otherwise each non-empty leaf's advertised
/// proposal capabilities are checked directly.
#[cfg(feature = "custom_proposal")]
pub fn can_support_proposal(&self, proposal_type: ProposalType) -> bool {
    #[cfg(feature = "tree_index")]
    let supported =
        self.index.count_supporting_proposal(proposal_type) == self.occupied_leaf_count();

    #[cfg(not(feature = "tree_index"))]
    let supported = self
        .nodes
        .non_empty_leaves()
        .all(|(_, leaf)| leaf.capabilities.proposals.contains(&proposal_type));

    supported
}
216
/// Test helper: adds each of `leaf_nodes` in order using default extensions,
/// then updates hashes for the inserted leaves and returns their indices.
///
/// Each insertion starts its search for an empty leaf at the index of the
/// previously inserted one.
#[cfg(test)]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn add_leaves<I: IdentityProvider, CP: CipherSuiteProvider>(
    &mut self,
    leaf_nodes: Vec<LeafNode>,
    id_provider: &I,
    cipher_suite_provider: &CP,
) -> Result<Vec<LeafIndex>, MlsError> {
    let no_extensions = ExtensionList::default();
    let mut cursor = LeafIndex(0);
    let mut inserted = Vec::new();

    for leaf in leaf_nodes {
        cursor = self
            .add_leaf(leaf, id_provider, &no_extensions, Some(cursor))
            .await?;

        inserted.push(cursor);
    }

    self.update_hashes(&inserted, cipher_suite_provider).await?;

    Ok(inserted)
}
239
non_empty_leaves(&self) -> impl Iterator<Item = (LeafIndex, &LeafNode)> + '_240 pub fn non_empty_leaves(&self) -> impl Iterator<Item = (LeafIndex, &LeafNode)> + '_ {
241 self.nodes.non_empty_leaves()
242 }
243
/// Iterates over the leaves as `Option<&LeafNode>` values, as produced by
/// the node vector's `leaves` iterator.
#[cfg(feature = "prior_epoch")]
pub fn leaves(&self) -> impl Iterator<Item = Option<&LeafNode>> + '_ {
    let Self { nodes, .. } = self;
    nodes.leaves()
}
248
update_node( &mut self, pub_key: crypto::HpkePublicKey, index: NodeIndex, ) -> Result<(), MlsError>249 pub(crate) fn update_node(
250 &mut self,
251 pub_key: crypto::HpkePublicKey,
252 index: NodeIndex,
253 ) -> Result<(), MlsError> {
254 self.nodes
255 .borrow_or_fill_node_as_parent(index, &pub_key)
256 .map(|p| {
257 p.public_key = pub_key;
258 p.unmerged_leaves = vec![];
259 })
260 }
261
/// Applies a validated update path sent by `sender` to this tree.
///
/// The sender's leaf is replaced by `update_path.leaf_node`, every present
/// node of the path is written to the matching entry of
/// `direct_copath(sender)`, the new leaf is passed to `index_insert`, and
/// finally `update_parent_hashes` is run for the sender.
///
/// # Errors
/// Fails when the sender's leaf cannot be borrowed, the identity provider
/// fails (`tree_index` only), a node update fails, `index_insert` rejects
/// the new leaf, or `update_parent_hashes` fails. Changes made before the
/// failing step are not undone by this function.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn apply_update_path<IP, CP>(
    &mut self,
    sender: LeafIndex,
    update_path: &ValidatedUpdatePath,
    extensions: &ExtensionList,
    identity_provider: IP,
    cipher_suite_provider: &CP,
) -> Result<(), MlsError>
where
    IP: IdentityProvider,
    CP: CipherSuiteProvider,
{
    // Install the new leaf node
    let existing_leaf = self.nodes.borrow_as_leaf_mut(sender)?;

    // Keep a copy of the leaf being replaced so it can be removed from the
    // index once the new nodes are in place.
    #[cfg(feature = "tree_index")]
    let original_leaf_node = existing_leaf.clone();

    #[cfg(feature = "tree_index")]
    let original_identity = identity_provider
        .identity(&original_leaf_node.signing_identity, extensions)
        .await
        .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;

    *existing_leaf = update_path.leaf_node.clone();

    // Update the rest of the nodes on the direct path
    let path = self.nodes.direct_copath(sender);

    // `zip` stops at the shorter of the two sequences; entries of
    // `update_path.nodes` that are `None` leave the tree node untouched.
    for (node, pn) in update_path.nodes.iter().zip(path) {
        node.as_ref()
            .map(|n| self.update_node(n.public_key.clone(), pn.path))
            .transpose()?;
    }

    #[cfg(feature = "tree_index")]
    self.index.remove(&original_leaf_node, &original_identity);

    // The first argument depends on the build: the tree index when
    // `tree_index` is enabled, the raw node vector otherwise.
    index_insert(
        #[cfg(feature = "tree_index")]
        &mut self.index,
        #[cfg(not(feature = "tree_index"))]
        &self.nodes,
        &update_path.leaf_node,
        sender,
        &identity_provider,
        extensions,
    )
    .await?;

    // Verify the parent hash of the new sender leaf node and update the parent hash values
    // in the local tree
    self.update_parent_hashes(sender, true, cipher_suite_provider)
        .await?;

    Ok(())
}
320
update_unmerged(&mut self, index: LeafIndex) -> Result<(), MlsError>321 fn update_unmerged(&mut self, index: LeafIndex) -> Result<(), MlsError> {
322 // For a given leaf index, find parent nodes and add the leaf to the unmerged leaf
323 self.nodes.direct_copath(index).into_iter().for_each(|i| {
324 if let Ok(p) = self.nodes.borrow_as_parent_mut(i.path) {
325 p.unmerged_leaves.push(index)
326 }
327 });
328
329 Ok(())
330 }
331
/// Applies the removes, updates and adds of `proposal_bundle` to the tree,
/// in that order, then trims the node vector and updates the hashes of the
/// touched leaves.
///
/// When `filter` is `true`, by-reference proposals that cannot be applied
/// are dropped from `proposal_bundle` instead of failing the whole call;
/// by-value removes and adds that fail still return their error. When
/// `filter` is `false`, the first failure is returned.
///
/// Returns the indices of the leaves that were added.
#[cfg(feature = "by_ref_proposal")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn batch_edit<I, CP>(
    &mut self,
    proposal_bundle: &mut ProposalBundle,
    extensions: &ExtensionList,
    id_provider: &I,
    cipher_suite_provider: &CP,
    filter: bool,
) -> Result<Vec<LeafIndex>, MlsError>
where
    I: IdentityProvider,
    CP: CipherSuiteProvider,
{
    // Apply removes (they commute with updates because they don't touch the same leaves)
    // Iterating in reverse keeps the remaining positions valid when a
    // rejected proposal is removed from the bundle by position.
    for i in (0..proposal_bundle.remove_proposals().len()).rev() {
        let index = proposal_bundle.remove_proposals()[i].proposal.to_remove;
        let res = self.nodes.blank_leaf_node(index);

        if res.is_ok() {
            // This shouldn't fail if `blank_leaf_node` succeeded.
            self.nodes.blank_direct_path(index)?;
        }

        #[cfg(feature = "tree_index")]
        if let Ok(old_leaf) = &res {
            // If this fails, it's not because the proposal is bad.
            let identity =
                identity(&old_leaf.signing_identity, id_provider, extensions).await?;

            self.index.remove(old_leaf, &identity);
        }

        // By-value proposals and unfiltered runs surface the error; a
        // failing by-reference proposal is dropped from the bundle.
        if proposal_bundle.remove_proposals()[i].is_by_value() || !filter {
            res?;
        } else if res.is_err() {
            proposal_bundle.remove::<RemoveProposal>(i);
        }
    }

    // Remove from the tree old leaves from updates
    let mut partial_updates = vec![];
    let senders = proposal_bundle.update_senders.iter().copied();

    for (i, (p, index)) in proposal_bundle.updates.iter().zip(senders).enumerate() {
        let new_leaf = p.proposal.leaf_node.clone();

        match self.nodes.blank_leaf_node(index) {
            Ok(old_leaf) => {
                #[cfg(feature = "tree_index")]
                let old_id =
                    identity(&old_leaf.signing_identity, id_provider, extensions).await?;

                #[cfg(feature = "tree_index")]
                self.index.remove(&old_leaf, &old_id);

                // Remember the bundle position `i` so a rejected update can
                // be removed from the bundle later.
                partial_updates.push((index, old_leaf, new_leaf, i));
            }
            _ => {
                // NOTE(review): in the filtered, by-reference case the update is
                // skipped here but not removed from the bundle — confirm intended.
                if !filter || !p.is_by_reference() {
                    return Err(MlsError::UpdatingNonExistingMember);
                }
            }
        }
    }

    // Snapshot of the index taken after the old leaves were removed; used
    // by the "revert all" path below.
    #[cfg(feature = "tree_index")]
    let index_clone = self.index.clone();

    let mut removed_leaves = vec![];
    let mut updated_indices = vec![];
    let mut bad_indices = vec![];

    // Apply updates one by one. If there's an update which we can't apply or revert, we revert
    // all updates.
    for (index, old_leaf, new_leaf, i) in partial_updates.into_iter() {
        #[cfg(feature = "tree_index")]
        let res =
            index_insert(&mut self.index, &new_leaf, index, id_provider, extensions).await;

        #[cfg(not(feature = "tree_index"))]
        let res = index_insert(&self.nodes, &new_leaf, index, id_provider, extensions).await;

        let err = res.is_err();

        if !filter {
            res?;
        }

        if !err {
            self.nodes.insert_leaf(index, new_leaf);
            removed_leaves.push(old_leaf);
            updated_indices.push(index);
        } else {
            // The new leaf was rejected: try to put the old leaf back.
            #[cfg(feature = "tree_index")]
            let res =
                index_insert(&mut self.index, &old_leaf, index, id_provider, extensions).await;

            #[cfg(not(feature = "tree_index"))]
            let res =
                index_insert(&self.nodes, &old_leaf, index, id_provider, extensions).await;

            if res.is_ok() {
                self.nodes.insert_leaf(index, old_leaf);
                bad_indices.push(i);
            } else {
                // Revert all updates and stop. We're already in the "filter" case, so we don't throw an error.
                #[cfg(feature = "tree_index")]
                {
                    self.index = index_clone;
                }

                removed_leaves
                    .into_iter()
                    .zip(updated_indices.iter())
                    .for_each(|(leaf, index)| self.nodes.insert_leaf(*index, leaf));

                updated_indices = vec![];
                break;
            }
        }
    }

    // If we managed to update something, blank direct paths
    updated_indices
        .iter()
        .try_for_each(|index| self.nodes.blank_direct_path(*index).map(|_| ()))?;

    // Remove rejected updates from applied proposals
    if updated_indices.is_empty() {
        // This takes care of the "revert all" scenario
        // NOTE(review): `update_senders` is left untouched in this branch —
        // confirm callers do not rely on it matching `updates`.
        proposal_bundle.updates = vec![];
    } else {
        // Reverse order so earlier positions stay valid while removing.
        for i in bad_indices.into_iter().rev() {
            proposal_bundle.remove::<UpdateProposal>(i);
            proposal_bundle.update_senders.remove(i);
        }
    }

    // Apply adds
    // Each add starts searching for an empty leaf at the previously added index.
    let mut start = LeafIndex(0);
    let mut added = vec![];
    let mut bad_indexes = vec![];

    for i in 0..proposal_bundle.additions.len() {
        let leaf = proposal_bundle.additions[i]
            .proposal
            .key_package
            .leaf_node
            .clone();

        let res = self
            .add_leaf(leaf, id_provider, extensions, Some(start))
            .await;

        if let Ok(index) = res {
            start = index;
            added.push(start);
        } else if proposal_bundle.additions[i].is_by_value() || !filter {
            res?;
        } else {
            bad_indexes.push(i);
        }
    }

    for i in bad_indexes.into_iter().rev() {
        proposal_bundle.remove::<AddProposal>(i);
    }

    self.nodes.trim();

    // Leaves whose hashes are updated: surviving removes, applied updates,
    // and added leaves.
    let updated_leaves = proposal_bundle
        .remove_proposals()
        .iter()
        .map(|p| p.proposal.to_remove)
        .chain(updated_indices)
        .chain(added.iter().copied())
        .collect_vec();

    self.update_hashes(&updated_leaves, cipher_suite_provider)
        .await?;

    Ok(added)
}
516
517 #[cfg(not(feature = "by_ref_proposal"))]
518 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
batch_edit_lite<I, CP>( &mut self, proposal_bundle: &ProposalBundle, extensions: &ExtensionList, id_provider: &I, cipher_suite_provider: &CP, ) -> Result<Vec<LeafIndex>, MlsError> where I: IdentityProvider, CP: CipherSuiteProvider,519 pub async fn batch_edit_lite<I, CP>(
520 &mut self,
521 proposal_bundle: &ProposalBundle,
522 extensions: &ExtensionList,
523 id_provider: &I,
524 cipher_suite_provider: &CP,
525 ) -> Result<Vec<LeafIndex>, MlsError>
526 where
527 I: IdentityProvider,
528 CP: CipherSuiteProvider,
529 {
530 // Apply removes
531 for p in &proposal_bundle.removals {
532 let index = p.proposal.to_remove;
533
534 #[cfg(feature = "tree_index")]
535 {
536 // If this fails, it's not because the proposal is bad.
537 let old_leaf = self.nodes.blank_leaf_node(index)?;
538
539 let identity =
540 identity(&old_leaf.signing_identity, id_provider, extensions).await?;
541
542 self.index.remove(&old_leaf, &identity);
543 }
544
545 #[cfg(not(feature = "tree_index"))]
546 self.nodes.blank_leaf_node(index)?;
547
548 self.nodes.blank_direct_path(index)?;
549 }
550
551 // Apply adds
552 let mut start = LeafIndex(0);
553 let mut added = vec![];
554
555 for p in &proposal_bundle.additions {
556 let leaf = p.proposal.key_package.leaf_node.clone();
557 start = self
558 .add_leaf(leaf, id_provider, extensions, Some(start))
559 .await?;
560 added.push(start);
561 }
562
563 self.nodes.trim();
564
565 let updated_leaves = proposal_bundle
566 .remove_proposals()
567 .iter()
568 .map(|p| p.proposal.to_remove)
569 .chain(added.iter().copied())
570 .collect_vec();
571
572 self.update_hashes(&updated_leaves, cipher_suite_provider)
573 .await?;
574
575 Ok(added)
576 }
577
578 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
add_leaf<I: IdentityProvider>( &mut self, leaf: LeafNode, id_provider: &I, extensions: &ExtensionList, start: Option<LeafIndex>, ) -> Result<LeafIndex, MlsError>579 pub(crate) async fn add_leaf<I: IdentityProvider>(
580 &mut self,
581 leaf: LeafNode,
582 id_provider: &I,
583 extensions: &ExtensionList,
584 start: Option<LeafIndex>,
585 ) -> Result<LeafIndex, MlsError> {
586 let index = self.nodes.next_empty_leaf(start.unwrap_or(LeafIndex(0)));
587
588 #[cfg(feature = "tree_index")]
589 index_insert(&mut self.index, &leaf, index, id_provider, extensions).await?;
590
591 #[cfg(not(feature = "tree_index"))]
592 index_insert(&self.nodes, &leaf, index, id_provider, extensions).await?;
593
594 self.nodes.insert_leaf(index, leaf);
595 self.update_unmerged(index)?;
596
597 Ok(index)
598 }
599 }
600
/// Asks `provider` for the identity bytes of `signing_id`, converting a
/// provider failure into `MlsError::IdentityProviderError`.
#[cfg(feature = "tree_index")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn identity<I: IdentityProvider>(
    signing_id: &SigningIdentity,
    provider: &I,
    extensions: &ExtensionList,
) -> Result<Vec<u8>, MlsError> {
    match provider.identity(signing_id, extensions).await {
        Ok(id) => Ok(id),
        Err(e) => Err(MlsError::IdentityProviderError(e.into_any_error())),
    }
}
613
/// Formats the tree using the rendering produced by
/// `tree_utils::build_ascii_tree` for the node vector.
#[cfg(feature = "std")]
impl Display for TreeKemPublic {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let rendered = tree_utils::build_ascii_tree(&self.nodes);
        write!(f, "{rendered}")
    }
}
620
621 #[cfg(test)]
622 use crate::group::{proposal::Proposal, proposal_filter::ProposalSource, Sender};
623
624 #[cfg(test)]
625 impl TreeKemPublic {
/// Test helper: replaces the leaf at `leaf_index` with `leaf_node` by
/// running `batch_edit` on a bundle holding a single by-value update
/// proposal, with default extensions and filtering enabled.
#[cfg(feature = "by_ref_proposal")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn update_leaf<I, CP>(
    &mut self,
    leaf_index: u32,
    leaf_node: LeafNode,
    identity_provider: &I,
    cipher_suite_provider: &CP,
) -> Result<(), MlsError>
where
    I: IdentityProvider,
    CP: CipherSuiteProvider,
{
    let mut bundle = ProposalBundle::default();

    bundle.add(
        Proposal::Update(UpdateProposal { leaf_node }),
        Sender::Member(leaf_index),
        ProposalSource::ByValue,
    );

    bundle.update_senders = vec![LeafIndex(leaf_index)];

    let no_extensions = ExtensionList::default();

    self.batch_edit(
        &mut bundle,
        &no_extensions,
        identity_provider,
        cipher_suite_provider,
        true,
    )
    .await
    .map(|_| ())
}
656
657 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
remove_leaves<I, CP>( &mut self, indexes: Vec<LeafIndex>, identity_provider: &I, cipher_suite_provider: &CP, ) -> Result<Vec<(LeafIndex, LeafNode)>, MlsError> where I: IdentityProvider, CP: CipherSuiteProvider,658 pub async fn remove_leaves<I, CP>(
659 &mut self,
660 indexes: Vec<LeafIndex>,
661 identity_provider: &I,
662 cipher_suite_provider: &CP,
663 ) -> Result<Vec<(LeafIndex, LeafNode)>, MlsError>
664 where
665 I: IdentityProvider,
666 CP: CipherSuiteProvider,
667 {
668 let old_tree = self.clone();
669
670 let proposals = indexes
671 .iter()
672 .copied()
673 .map(|to_remove| Proposal::Remove(RemoveProposal { to_remove }));
674
675 let mut bundle = ProposalBundle::default();
676
677 for p in proposals {
678 bundle.add(p, Sender::Member(0), ProposalSource::ByValue);
679 }
680
681 #[cfg(feature = "by_ref_proposal")]
682 self.batch_edit(
683 &mut bundle,
684 &Default::default(),
685 identity_provider,
686 cipher_suite_provider,
687 true,
688 )
689 .await?;
690
691 #[cfg(not(feature = "by_ref_proposal"))]
692 self.batch_edit_lite(
693 &bundle,
694 &Default::default(),
695 identity_provider,
696 cipher_suite_provider,
697 )
698 .await?;
699
700 bundle
701 .removals
702 .iter()
703 .map(|p| {
704 let index = p.proposal.to_remove;
705 let leaf = old_tree.get_leaf_node(index)?.clone();
706 Ok((index, leaf))
707 })
708 .collect()
709 }
710
get_leaf_nodes(&self) -> Vec<&LeafNode>711 pub fn get_leaf_nodes(&self) -> Vec<&LeafNode> {
712 self.nodes.non_empty_leaves().map(|(_, l)| l).collect()
713 }
714 }
715
716 #[cfg(test)]
717 pub(crate) mod test_utils {
718 use crate::crypto::test_utils::TestCryptoProvider;
719 use crate::signer::Signable;
720 use alloc::vec::Vec;
721 use alloc::{format, vec};
722 use mls_rs_core::crypto::CipherSuiteProvider;
723 use mls_rs_core::group::Capabilities;
724 use mls_rs_core::identity::BasicCredential;
725
726 use crate::identity::test_utils::get_test_signing_identity;
727 use crate::{
728 cipher_suite::CipherSuite,
729 crypto::{HpkeSecretKey, SignatureSecretKey},
730 identity::basic::BasicIdentityProvider,
731 tree_kem::leaf_node::test_utils::get_basic_test_node_sig_key,
732 };
733
734 use super::leaf_node::{ConfigProperties, LeafNodeSigningContext};
735 use super::node::LeafIndex;
736 use super::Lifetime;
737 use super::{
738 leaf_node::{test_utils::get_basic_test_node, LeafNode},
739 TreeKemPrivate, TreeKemPublic,
740 };
741
/// Test fixture produced by `get_test_tree`: a tree derived from a single
/// creator leaf, together with the creator's key material.
#[derive(Debug)]
pub(crate) struct TestTree {
    /// Public tree returned by `TreeKemPublic::derive`.
    pub public: TreeKemPublic,
    /// Private tree returned by `TreeKemPublic::derive`.
    pub private: TreeKemPrivate,
    /// The leaf node the tree was derived from.
    pub creator_leaf: LeafNode,
    /// Signature secret key generated alongside the creator leaf.
    pub creator_signing_key: SignatureSecretKey,
    /// HPKE secret key generated alongside the creator leaf.
    pub creator_hpke_secret: HpkeSecretKey,
}
750
751 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_test_tree(cipher_suite: CipherSuite) -> TestTree752 pub(crate) async fn get_test_tree(cipher_suite: CipherSuite) -> TestTree {
753 let (creator_leaf, creator_hpke_secret, creator_signing_key) =
754 get_basic_test_node_sig_key(cipher_suite, "creator").await;
755
756 let (test_public, test_private) = TreeKemPublic::derive(
757 creator_leaf.clone(),
758 creator_hpke_secret.clone(),
759 &BasicIdentityProvider,
760 &Default::default(),
761 )
762 .await
763 .unwrap();
764
765 TestTree {
766 public: test_public,
767 private: test_private,
768 creator_leaf,
769 creator_signing_key,
770 creator_hpke_secret,
771 }
772 }
773
774 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_test_leaf_nodes(cipher_suite: CipherSuite) -> Vec<LeafNode>775 pub async fn get_test_leaf_nodes(cipher_suite: CipherSuite) -> Vec<LeafNode> {
776 [
777 get_basic_test_node(cipher_suite, "A").await,
778 get_basic_test_node(cipher_suite, "B").await,
779 get_basic_test_node(cipher_suite, "C").await,
780 ]
781 .to_vec()
782 }
783
784 impl TreeKemPublic {
785 #[cfg(feature = "tree_index")]
equal_internals(&self, other: &TreeKemPublic) -> bool786 pub fn equal_internals(&self, other: &TreeKemPublic) -> bool {
787 self.tree_hashes == other.tree_hashes && self.index == other.index
788 }
789 }
790
/// Test fixture pairing a tree with the signing keys of its members.
#[derive(Debug, Clone)]
pub struct TreeWithSigners {
    /// The tree under test.
    pub tree: TreeKemPublic,
    /// One entry per leaf position: `Some(key)` for a member added through
    /// `add_member`, `None` after `remove_member`.
    pub signers: Vec<Option<SignatureSecretKey>>,
    /// Group id used in the leaf signing context; `make_full_tree` fills it
    /// with random bytes.
    pub group_id: Vec<u8>,
}
797
798 impl TreeWithSigners {
799 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
make_full_tree<P: CipherSuiteProvider>( n_leaves: u32, cs: &P, ) -> TreeWithSigners800 pub async fn make_full_tree<P: CipherSuiteProvider>(
801 n_leaves: u32,
802 cs: &P,
803 ) -> TreeWithSigners {
804 let mut tree = TreeWithSigners {
805 tree: TreeKemPublic::new(),
806 signers: vec![],
807 group_id: cs.random_bytes_vec(cs.kdf_extract_size()).unwrap(),
808 };
809
810 tree.add_member("Alice", cs).await;
811
812 // A adds B, B adds C, C adds D etc.
813 for i in 1..n_leaves {
814 tree.add_member(&format!("Alice{i}"), cs).await;
815 tree.update_committer_path(i - 1, cs).await;
816 }
817
818 tree
819 }
820
821 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
add_member<P: CipherSuiteProvider>(&mut self, name: &str, cs: &P)822 pub async fn add_member<P: CipherSuiteProvider>(&mut self, name: &str, cs: &P) {
823 let (leaf, signer) = make_leaf(name, cs).await;
824 let index = self.tree.nodes.next_empty_leaf(LeafIndex(0));
825 self.tree.nodes.insert_leaf(index, leaf);
826 self.tree.update_unmerged(index).unwrap();
827 let index = *index as usize;
828
829 match self.signers.len() {
830 l if l == index => self.signers.push(Some(signer)),
831 l if l > index => self.signers[index] = Some(signer),
832 _ => panic!("signer tree size mismatch"),
833 }
834 }
835
/// Blanks the direct path and the leaf of `member`, then clears its entry in
/// `signers`.
///
/// Panics if either blanking step fails, or with
/// "signer tree size mismatch" if `signers` has no entry for `member`.
#[cfg(feature = "rfc_compliant")]
#[cfg_attr(coverage_nightly, coverage(off))]
pub fn remove_member(&mut self, member: u32) {
    let target = LeafIndex(member);

    self.tree.nodes.blank_direct_path(target).unwrap();
    self.tree.nodes.blank_leaf_node(target).unwrap();

    let signer_slot = self
        .signers
        .get_mut(member as usize)
        .expect("signer tree size mismatch");

    *signer_slot = None;
}
851
/// Test helper emulating a path update by `committer`: fresh public keys are
/// installed on the committer's path, parent hashes are updated, and the
/// committer's leaf is re-signed with this fixture's group id.
///
/// Panics if any of the underlying operations fails or if no signer is
/// stored for `committer`.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn update_committer_path<P: CipherSuiteProvider>(
    &mut self,
    committer: u32,
    cs: &P,
) {
    let committer = LeafIndex(committer);

    let path = self.tree.nodes.direct_copath(committer);
    let filtered = self.tree.nodes.filtered(committer).unwrap();

    // Only path entries whose `filtered` flag is false receive a new key;
    // `.1` of the generated key pair is the public key `update_node` takes.
    for (n, f) in path.into_iter().zip(filtered) {
        if !f {
            self.tree
                .update_node(cs.kem_generate().await.unwrap().1, n.path)
                .unwrap();
        }
    }

    // Clear the stored current hashes before calling `tree_hash`
    // (presumably to force a full recomputation — confirm).
    self.tree.tree_hashes.current = vec![];
    self.tree.tree_hash(cs).await.unwrap();

    self.tree
        .update_parent_hashes(committer, false, cs)
        .await
        .unwrap();

    // Same reset + recompute again after the parent hashes changed.
    self.tree.tree_hashes.current = vec![];
    self.tree.tree_hash(cs).await.unwrap();

    let context = LeafNodeSigningContext {
        group_id: Some(&self.group_id),
        leaf_index: Some(*committer),
    };

    let signer = self.signers[*committer as usize].as_ref().unwrap();

    self.tree
        .nodes
        .borrow_as_leaf_mut(committer)
        .unwrap()
        .sign(cs, signer, &context)
        .await
        .unwrap();

    // And once more after the committer's leaf was re-signed.
    self.tree.tree_hashes.current = vec![];
    self.tree.tree_hash(cs).await.unwrap();
}
900 }
901
902 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
make_leaf<P: CipherSuiteProvider>( name: &str, cs: &P, ) -> (LeafNode, SignatureSecretKey)903 pub async fn make_leaf<P: CipherSuiteProvider>(
904 name: &str,
905 cs: &P,
906 ) -> (LeafNode, SignatureSecretKey) {
907 let (signing_identity, signature_key) =
908 get_test_signing_identity(cs.cipher_suite(), name.as_bytes()).await;
909
910 let capabilities = Capabilities {
911 credentials: vec![BasicCredential::credential_type()],
912 cipher_suites: TestCryptoProvider::all_supported_cipher_suites(),
913 ..Default::default()
914 };
915
916 let properties = ConfigProperties {
917 capabilities,
918 extensions: Default::default(),
919 };
920
921 let (leaf, _) = LeafNode::generate(
922 cs,
923 properties,
924 signing_identity,
925 &signature_key,
926 Lifetime::years(1).unwrap(),
927 )
928 .await
929 .unwrap();
930
931 (leaf, signature_key)
932 }
933 }
934
935 #[cfg(test)]
936 mod tests {
937 use crate::client::test_utils::TEST_CIPHER_SUITE;
938 use crate::crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider};
939
940 #[cfg(feature = "custom_proposal")]
941 use crate::group::proposal::ProposalType;
942
943 use crate::identity::basic::BasicIdentityProvider;
944 use crate::tree_kem::leaf_node::LeafNode;
945 use crate::tree_kem::node::{LeafIndex, Node, NodeIndex, NodeTypeResolver, Parent};
946 use crate::tree_kem::parent_hash::ParentHash;
947 use crate::tree_kem::test_utils::{get_test_leaf_nodes, get_test_tree};
948 use crate::tree_kem::{MlsError, TreeKemPublic};
949 use alloc::borrow::ToOwned;
950 use alloc::vec;
951 use alloc::vec::Vec;
952 use assert_matches::assert_matches;
953
954 #[cfg(feature = "by_ref_proposal")]
955 use alloc::boxed::Box;
956
957 #[cfg(feature = "by_ref_proposal")]
958 use crate::{
959 client::test_utils::TEST_PROTOCOL_VERSION,
960 group::{
961 proposal::{Proposal, RemoveProposal, UpdateProposal},
962 proposal_filter::{ProposalBundle, ProposalSource},
963 proposal_ref::ProposalRef,
964 Sender,
965 },
966 key_package::test_utils::test_key_package,
967 };
968
// Needed by tests gated on either feature. The second feature name was
// previously misspelled ("custo_proposal"), so the import never activated
// for `custom_proposal` builds.
#[cfg(any(feature = "by_ref_proposal", feature = "custom_proposal"))]
use crate::tree_kem::leaf_node::test_utils::get_basic_test_node;
971
972 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_derive()973 async fn test_derive() {
974 for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
975 let test_tree = get_test_tree(cipher_suite).await;
976
977 assert_eq!(
978 test_tree.public.nodes[0],
979 Some(Node::Leaf(test_tree.creator_leaf.clone()))
980 );
981
982 assert_eq!(test_tree.private.self_index, LeafIndex(0));
983
984 assert_eq!(
985 test_tree.private.secret_keys[0],
986 Some(test_tree.creator_hpke_secret)
987 );
988 }
989 }
990
991 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_import_export()992 async fn test_import_export() {
993 let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
994 let mut test_tree = get_test_tree(TEST_CIPHER_SUITE).await;
995
996 let additional_key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
997
998 test_tree
999 .public
1000 .add_leaves(
1001 additional_key_packages,
1002 &BasicIdentityProvider,
1003 &cipher_suite_provider,
1004 )
1005 .await
1006 .unwrap();
1007
1008 let imported = TreeKemPublic::import_node_data(
1009 test_tree.public.nodes.clone(),
1010 &BasicIdentityProvider,
1011 &Default::default(),
1012 )
1013 .await
1014 .unwrap();
1015
1016 assert_eq!(test_tree.public.nodes, imported.nodes);
1017
1018 #[cfg(feature = "tree_index")]
1019 assert_eq!(test_tree.public.index, imported.index);
1020 }
1021
1022 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_add_leaf()1023 async fn test_add_leaf() {
1024 let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
1025 let mut tree = TreeKemPublic::new();
1026
1027 let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
1028
1029 let res = tree
1030 .add_leaves(
1031 leaf_nodes.clone(),
1032 &BasicIdentityProvider,
1033 &cipher_suite_provider,
1034 )
1035 .await
1036 .unwrap();
1037
1038 // The leaf count should be equal to the number of packages we added
1039 assert_eq!(res.len(), leaf_nodes.len());
1040 assert_eq!(tree.occupied_leaf_count(), leaf_nodes.len() as u32);
1041
1042 // Each added package should be at the proper index and searchable in the tree
1043 res.into_iter().zip(leaf_nodes.clone()).for_each(|(r, kp)| {
1044 assert_eq!(tree.get_leaf_node(r).unwrap(), &kp);
1045 });
1046
1047 // Verify the underlying state
1048 #[cfg(feature = "tree_index")]
1049 assert_eq!(tree.index.len(), tree.occupied_leaf_count() as usize);
1050
1051 assert_eq!(tree.nodes.len(), 5);
1052 assert_eq!(tree.nodes[0], leaf_nodes[0].clone().into());
1053 assert_eq!(tree.nodes[1], None);
1054 assert_eq!(tree.nodes[2], leaf_nodes[1].clone().into());
1055 assert_eq!(tree.nodes[3], None);
1056 assert_eq!(tree.nodes[4], leaf_nodes[2].clone().into());
1057 }
1058
1059 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_get_key_packages()1060 async fn test_get_key_packages() {
1061 let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
1062 let mut tree = TreeKemPublic::new();
1063
1064 let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
1065
1066 tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
1067 .await
1068 .unwrap();
1069
1070 let key_packages = tree.get_leaf_nodes();
1071 assert_eq!(key_packages, key_packages.to_owned());
1072 }
1073
1074 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_add_leaf_duplicate()1075 async fn test_add_leaf_duplicate() {
1076 let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
1077 let mut tree = TreeKemPublic::new();
1078
1079 let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
1080
1081 tree.add_leaves(
1082 key_packages.clone(),
1083 &BasicIdentityProvider,
1084 &cipher_suite_provider,
1085 )
1086 .await
1087 .unwrap();
1088
1089 let res = tree
1090 .add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
1091 .await;
1092
1093 assert_matches!(res, Err(MlsError::DuplicateLeafData(_)));
1094 }
1095
1096 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_add_leaf_empty_leaf()1097 async fn test_add_leaf_empty_leaf() {
1098 let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
1099 let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
1100 let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
1101
1102 tree.add_leaves(
1103 [key_packages[0].clone()].to_vec(),
1104 &BasicIdentityProvider,
1105 &cipher_suite_provider,
1106 )
1107 .await
1108 .unwrap();
1109
1110 tree.nodes[0] = None; // Set the original first node to none
1111 //
1112 tree.add_leaves(
1113 [key_packages[1].clone()].to_vec(),
1114 &BasicIdentityProvider,
1115 &cipher_suite_provider,
1116 )
1117 .await
1118 .unwrap();
1119
1120 assert_eq!(tree.nodes[0], key_packages[1].clone().into());
1121 assert_eq!(tree.nodes[1], None);
1122 assert_eq!(tree.nodes[2], key_packages[0].clone().into());
1123 assert_eq!(tree.nodes.len(), 3)
1124 }
1125
1126 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_add_leaf_unmerged()1127 async fn test_add_leaf_unmerged() {
1128 let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
1129 let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
1130 let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
1131
1132 tree.add_leaves(
1133 [key_packages[0].clone(), key_packages[1].clone()].to_vec(),
1134 &BasicIdentityProvider,
1135 &cipher_suite_provider,
1136 )
1137 .await
1138 .unwrap();
1139
1140 tree.nodes[3] = Parent {
1141 public_key: vec![].into(),
1142 parent_hash: ParentHash::empty(),
1143 unmerged_leaves: vec![],
1144 }
1145 .into();
1146
1147 tree.add_leaves(
1148 [key_packages[2].clone()].to_vec(),
1149 &BasicIdentityProvider,
1150 &cipher_suite_provider,
1151 )
1152 .await
1153 .unwrap();
1154
1155 assert_eq!(
1156 tree.nodes[3].as_parent().unwrap().unmerged_leaves,
1157 vec![LeafIndex(3)]
1158 )
1159 }
1160
/// Updates leaf 1 of a populated tree and checks that the leaf count is
/// unchanged, the new leaf is stored at the same index, and the parent nodes
/// that were filled in beforehand are blank again.
#[cfg(feature = "by_ref_proposal")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn test_update_leaf() {
    let cs_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

    // Start from the test tree and add the test leaves to it.
    let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
    let leaves = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

    tree.add_leaves(leaves, &BasicIdentityProvider, &cs_provider)
        .await
        .unwrap();

    // Fill the `path` nodes reported for leaf 0 so that clearing is observable.
    for entry in tree.nodes.direct_copath(LeafIndex(0)).iter() {
        tree.nodes
            .borrow_or_fill_node_as_parent(entry.path, &b"pub_key".to_vec().into())
            .unwrap();
    }

    let leaf_count_before = tree.occupied_leaf_count();
    let target_index = LeafIndex(1);

    let replacement = get_basic_test_node(TEST_CIPHER_SUITE, "A").await;

    tree.update_leaf(
        *target_index,
        replacement.clone(),
        &BasicIdentityProvider,
        &cs_provider,
    )
    .await
    .unwrap();

    // An update replaces a leaf in place, so the occupied count is the same.
    assert_eq!(tree.occupied_leaf_count(), leaf_count_before);

    // With the index feature on, it still has one entry per occupied leaf.
    #[cfg(feature = "tree_index")]
    assert_eq!(tree.index.len() as u32, tree.occupied_leaf_count());

    // The target slot now holds the replacement leaf.
    let stored = tree.get_leaf_node(target_index).unwrap();
    assert_eq!(stored, &replacement);

    // Every previously filled parent node has been blanked.
    for entry in tree.nodes.direct_copath(LeafIndex(0)).iter() {
        assert!(tree.nodes[entry.path as usize].is_none());
    }
}
1213
/// Updating index 128, which is not a member of the test tree, must fail with
/// `MlsError::UpdatingNonExistingMember`.
#[cfg(feature = "by_ref_proposal")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn test_update_leaf_not_found() {
    let cs_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

    // Start from the test tree and add the test leaves to it.
    let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
    let leaves = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

    tree.add_leaves(leaves, &BasicIdentityProvider, &cs_provider)
        .await
        .unwrap();

    let replacement = get_basic_test_node(TEST_CIPHER_SUITE, "new").await;

    let outcome = tree
        .update_leaf(128, replacement, &BasicIdentityProvider, &cs_provider)
        .await;

    assert_matches!(outcome, Err(MlsError::UpdatingNonExistingMember));
}
1241
1242 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_remove_leaf()1243 async fn test_remove_leaf() {
1244 let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
1245
1246 // Create a tree
1247 let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
1248 let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
1249
1250 let indexes = tree
1251 .add_leaves(
1252 key_packages.clone(),
1253 &BasicIdentityProvider,
1254 &cipher_suite_provider,
1255 )
1256 .await
1257 .unwrap();
1258
1259 let original_leaf_count = tree.occupied_leaf_count();
1260
1261 // Remove two leaves from the tree
1262 let expected_result: Vec<(LeafIndex, LeafNode)> =
1263 indexes.clone().into_iter().zip(key_packages).collect();
1264
1265 let res = tree
1266 .remove_leaves(
1267 indexes.clone(),
1268 &BasicIdentityProvider,
1269 &cipher_suite_provider,
1270 )
1271 .await
1272 .unwrap();
1273
1274 // The order may change
1275 assert!(res.iter().all(|x| expected_result.contains(x)));
1276 assert!(expected_result.iter().all(|x| res.contains(x)));
1277
1278 // The leaves should be removed from the tree
1279 assert_eq!(
1280 tree.occupied_leaf_count(),
1281 original_leaf_count - indexes.len() as u32
1282 );
1283 }
1284
1285 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_remove_leaf_middle()1286 async fn test_remove_leaf_middle() {
1287 let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
1288
1289 // Create a tree
1290 let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
1291 let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
1292
1293 let to_remove = tree
1294 .add_leaves(
1295 leaf_nodes.clone(),
1296 &BasicIdentityProvider,
1297 &cipher_suite_provider,
1298 )
1299 .await
1300 .unwrap()[0];
1301
1302 let original_leaf_count = tree.occupied_leaf_count();
1303
1304 let res = tree
1305 .remove_leaves(
1306 vec![to_remove],
1307 &BasicIdentityProvider,
1308 &cipher_suite_provider,
1309 )
1310 .await
1311 .unwrap();
1312
1313 assert_eq!(res, vec![(to_remove, leaf_nodes[0].clone())]);
1314
1315 // The leaf count should have been reduced by 1
1316 assert_eq!(tree.occupied_leaf_count(), original_leaf_count - 1);
1317
1318 // There should be a blank in the tree
1319 assert_eq!(
1320 tree.nodes.get(NodeIndex::from(to_remove) as usize).unwrap(),
1321 &None
1322 );
1323 }
1324
1325 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_create_blanks()1326 async fn test_create_blanks() {
1327 let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
1328
1329 // Create a tree
1330 let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
1331
1332 let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
1333
1334 tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
1335 .await
1336 .unwrap();
1337
1338 let original_leaf_count = tree.occupied_leaf_count();
1339
1340 let to_remove = vec![LeafIndex(2)];
1341
1342 // Remove the leaf from the tree
1343 tree.remove_leaves(to_remove, &BasicIdentityProvider, &cipher_suite_provider)
1344 .await
1345 .unwrap();
1346
1347 // The occupied leaf count should have been reduced by 1
1348 assert_eq!(tree.occupied_leaf_count(), original_leaf_count - 1);
1349
1350 // The total leaf count should remain unchanged
1351 assert_eq!(tree.total_leaf_count(), original_leaf_count);
1352
1353 // The location of key_packages[1] should now be blank
1354 let removed_location = tree
1355 .nodes
1356 .get(NodeIndex::from(LeafIndex(2)) as usize)
1357 .unwrap();
1358
1359 assert_eq!(removed_location, &None);
1360 }
1361
1362 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_remove_leaf_failure()1363 async fn test_remove_leaf_failure() {
1364 let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
1365
1366 // Create a tree
1367 let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
1368
1369 let res = tree
1370 .remove_leaves(
1371 vec![LeafIndex(128)],
1372 &BasicIdentityProvider,
1373 &cipher_suite_provider,
1374 )
1375 .await;
1376
1377 assert_matches!(res, Err(MlsError::InvalidNodeIndex(256)));
1378 }
1379
1380 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_find_leaf_node()1381 async fn test_find_leaf_node() {
1382 let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
1383 // Create a tree
1384 let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
1385
1386 let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
1387
1388 tree.add_leaves(
1389 leaf_nodes.clone(),
1390 &BasicIdentityProvider,
1391 &cipher_suite_provider,
1392 )
1393 .await
1394 .unwrap();
1395
1396 // Find each node
1397 for (i, leaf_node) in leaf_nodes.iter().enumerate() {
1398 let expected_index = LeafIndex(i as u32 + 1);
1399 assert_eq!(tree.find_leaf_node(leaf_node), Some(expected_index));
1400 }
1401 }
1402
1403 // TODO add test for the lite version
/// Runs `batch_edit` with one add, one update and one remove proposal and
/// checks that the bundle still holds one proposal of each kind afterwards.
#[cfg(feature = "by_ref_proposal")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn batch_edit_works() {
    let cs_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

    let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
    let leaves = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

    tree.add_leaves(leaves, &BasicIdentityProvider, &cs_provider)
        .await
        .unwrap();

    let mut bundle = ProposalBundle::default();

    // Add proposal, sent by value from member 0.
    let new_member = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "D").await;

    bundle.add(
        Proposal::Add(Box::new(new_member.into())),
        Sender::Member(0),
        ProposalSource::ByValue,
    );

    // Update proposal, sent by reference from member 1.
    let refreshed_leaf = get_basic_test_node(TEST_CIPHER_SUITE, "A").await;

    let update = Proposal::Update(UpdateProposal {
        leaf_node: refreshed_leaf,
    });

    let update_ref = ProposalRef::new_fake(vec![1, 2, 3]);

    bundle.add(
        update,
        Sender::Member(1),
        ProposalSource::ByReference(update_ref),
    );

    bundle.update_senders = vec![LeafIndex(1)];

    // Remove proposal for leaf 2, sent by value from member 0.
    let remove = Proposal::Remove(RemoveProposal {
        to_remove: LeafIndex(2),
    });

    bundle.add(remove, Sender::Member(0), ProposalSource::ByValue);

    tree.batch_edit(
        &mut bundle,
        &Default::default(),
        &BasicIdentityProvider,
        &cs_provider,
        true,
    )
    .await
    .unwrap();

    // Each proposal kind is still represented exactly once in the bundle.
    assert_eq!(bundle.add_proposals().len(), 1);
    assert_eq!(bundle.remove_proposals().len(), 1);
    assert_eq!(bundle.update_proposals().len(), 1);
}
1456
/// Checks `can_support_proposal`: it is true for a custom proposal type while
/// every leaf lists that type in its capabilities, and becomes false once a
/// leaf built by `get_basic_test_node` (not modified here) is added.
#[cfg(feature = "custom_proposal")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn custom_proposal_support() {
    let cs_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
    let mut tree = TreeKemPublic::new();

    let custom_type = ProposalType::from(42);

    // Advertise the custom proposal type on every leaf before insertion.
    let mut leaves = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

    for leaf in leaves.iter_mut() {
        leaf.capabilities.proposals.push(custom_type);
    }

    tree.add_leaves(leaves, &BasicIdentityProvider, &cs_provider)
        .await
        .unwrap();

    // Supported for the advertised type, not for a type nobody advertised.
    assert!(tree.can_support_proposal(custom_type));
    assert!(!tree.can_support_proposal(ProposalType::from(43)));

    // Add one more leaf whose capabilities were not extended with the custom type.
    let plain_leaf = get_basic_test_node(TEST_CIPHER_SUITE, "another").await;

    tree.add_leaves(vec![plain_leaf], &BasicIdentityProvider, &cs_provider)
        .await
        .unwrap();

    assert!(!tree.can_support_proposal(custom_type));
}
1490 }
1491