1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec;
6 use alloc::vec::Vec;
7 #[cfg(feature = "std")]
8 use core::fmt::Display;
9 use itertools::Itertools;
10 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
11 use mls_rs_core::extension::ExtensionList;
12 
13 use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider};
14 
15 #[cfg(feature = "tree_index")]
16 use mls_rs_core::identity::SigningIdentity;
17 
18 use math as tree_math;
19 use node::{LeafIndex, NodeIndex, NodeVec};
20 
21 use self::leaf_node::LeafNode;
22 
23 use crate::client::MlsError;
24 use crate::crypto::{self, CipherSuiteProvider, HpkeSecretKey};
25 
26 #[cfg(feature = "by_ref_proposal")]
27 use crate::group::proposal::{AddProposal, UpdateProposal};
28 
29 #[cfg(any(test, feature = "by_ref_proposal"))]
30 use crate::group::proposal::RemoveProposal;
31 
32 use crate::group::proposal_filter::ProposalBundle;
33 use crate::tree_kem::tree_hash::TreeHashes;
34 
35 mod capabilities;
36 pub(crate) mod hpke_encryption;
37 mod lifetime;
38 pub(crate) mod math;
39 pub mod node;
40 pub mod parent_hash;
41 pub mod path_secret;
42 mod private;
43 mod tree_hash;
44 pub mod tree_validator;
45 pub mod update_path;
46 
47 pub use capabilities::*;
48 pub use lifetime::*;
49 pub(crate) use private::*;
50 pub use update_path::*;
51 
52 use tree_index::*;
53 
54 pub mod kem;
55 pub mod leaf_node;
56 pub mod leaf_node_validator;
57 mod tree_index;
58 
59 #[cfg(feature = "std")]
60 pub(crate) mod tree_utils;
61 
62 #[cfg(test)]
63 mod interop_test_vectors;
64 
65 #[cfg(feature = "custom_proposal")]
66 use crate::group::proposal::ProposalType;
67 
/// Public state of the ratchet tree: the node vector, its tree hashes and, with the
/// `tree_index` feature, a lookup index over the leaves (identities, supported proposals).
#[derive(Clone, Debug, MlsEncode, MlsDecode, MlsSize, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TreeKemPublic {
    /// Leaf lookup index, only present with `tree_index`. It is skipped by serde.
    /// NOTE(review): presumably rebuilt with `initialize_index_if_necessary` after
    /// deserialization — confirm at the call sites.
    #[cfg(feature = "tree_index")]
    #[cfg_attr(feature = "serde", serde(skip))]
    index: TreeIndex,
    /// Every node of the tree (leaves and parents).
    pub(crate) nodes: NodeVec,
    /// Tree hashes of the nodes; presumably kept current by `update_hashes` (defined
    /// elsewhere) — confirm.
    tree_hashes: TreeHashes,
}
77 
78 impl PartialEq for TreeKemPublic {
eq(&self, other: &Self) -> bool79     fn eq(&self, other: &Self) -> bool {
80         self.nodes == other.nodes
81     }
82 }
83 
84 impl TreeKemPublic {
new() -> TreeKemPublic85     pub fn new() -> TreeKemPublic {
86         Default::default()
87     }
88 
89     #[cfg_attr(not(feature = "tree_index"), allow(unused))]
90     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
import_node_data<IP>( nodes: NodeVec, identity_provider: &IP, extensions: &ExtensionList, ) -> Result<TreeKemPublic, MlsError> where IP: IdentityProvider,91     pub(crate) async fn import_node_data<IP>(
92         nodes: NodeVec,
93         identity_provider: &IP,
94         extensions: &ExtensionList,
95     ) -> Result<TreeKemPublic, MlsError>
96     where
97         IP: IdentityProvider,
98     {
99         let mut tree = TreeKemPublic {
100             nodes,
101             ..Default::default()
102         };
103 
104         #[cfg(feature = "tree_index")]
105         tree.initialize_index_if_necessary(identity_provider, extensions)
106             .await?;
107 
108         Ok(tree)
109     }
110 
111     #[cfg(feature = "tree_index")]
112     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
initialize_index_if_necessary<IP: IdentityProvider>( &mut self, identity_provider: &IP, extensions: &ExtensionList, ) -> Result<(), MlsError>113     pub(crate) async fn initialize_index_if_necessary<IP: IdentityProvider>(
114         &mut self,
115         identity_provider: &IP,
116         extensions: &ExtensionList,
117     ) -> Result<(), MlsError> {
118         if !self.index.is_initialized() {
119             self.index = TreeIndex::new();
120 
121             for (leaf_index, leaf) in self.nodes.non_empty_leaves() {
122                 index_insert(
123                     &mut self.index,
124                     leaf,
125                     leaf_index,
126                     identity_provider,
127                     extensions,
128                 )
129                 .await?;
130             }
131         }
132 
133         Ok(())
134     }
135 
136     #[cfg(feature = "tree_index")]
get_leaf_node_with_identity(&self, identity: &[u8]) -> Option<LeafIndex>137     pub(crate) fn get_leaf_node_with_identity(&self, identity: &[u8]) -> Option<LeafIndex> {
138         self.index.get_leaf_index_with_identity(identity)
139     }
140 
141     #[cfg(not(feature = "tree_index"))]
142     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_leaf_node_with_identity<I: IdentityProvider>( &self, identity: &[u8], id_provider: &I, extensions: &ExtensionList, ) -> Result<Option<LeafIndex>, MlsError>143     pub(crate) async fn get_leaf_node_with_identity<I: IdentityProvider>(
144         &self,
145         identity: &[u8],
146         id_provider: &I,
147         extensions: &ExtensionList,
148     ) -> Result<Option<LeafIndex>, MlsError> {
149         for (i, leaf) in self.nodes.non_empty_leaves() {
150             let leaf_id = id_provider
151                 .identity(&leaf.signing_identity, extensions)
152                 .await
153                 .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
154 
155             if leaf_id == identity {
156                 return Ok(Some(i));
157             }
158         }
159 
160         Ok(None)
161     }
162 
163     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
derive<I: IdentityProvider>( leaf_node: LeafNode, secret_key: HpkeSecretKey, identity_provider: &I, extensions: &ExtensionList, ) -> Result<(TreeKemPublic, TreeKemPrivate), MlsError>164     pub async fn derive<I: IdentityProvider>(
165         leaf_node: LeafNode,
166         secret_key: HpkeSecretKey,
167         identity_provider: &I,
168         extensions: &ExtensionList,
169     ) -> Result<(TreeKemPublic, TreeKemPrivate), MlsError> {
170         let mut public_tree = TreeKemPublic::new();
171 
172         public_tree
173             .add_leaf(leaf_node, identity_provider, extensions, None)
174             .await?;
175 
176         let private_tree = TreeKemPrivate::new_self_leaf(LeafIndex(0), secret_key);
177 
178         Ok((public_tree, private_tree))
179     }
180 
total_leaf_count(&self) -> u32181     pub fn total_leaf_count(&self) -> u32 {
182         self.nodes.total_leaf_count()
183     }
184 
185     #[cfg(any(test, all(feature = "custom_proposal", feature = "tree_index")))]
occupied_leaf_count(&self) -> u32186     pub fn occupied_leaf_count(&self) -> u32 {
187         self.nodes.occupied_leaf_count()
188     }
189 
get_leaf_node(&self, index: LeafIndex) -> Result<&LeafNode, MlsError>190     pub fn get_leaf_node(&self, index: LeafIndex) -> Result<&LeafNode, MlsError> {
191         self.nodes.borrow_as_leaf(index)
192     }
193 
find_leaf_node(&self, leaf_node: &LeafNode) -> Option<LeafIndex>194     pub fn find_leaf_node(&self, leaf_node: &LeafNode) -> Option<LeafIndex> {
195         self.nodes.non_empty_leaves().find_map(
196             |(index, node)| {
197                 if node == leaf_node {
198                     Some(index)
199                 } else {
200                     None
201                 }
202             },
203         )
204     }
205 
206     #[cfg(feature = "custom_proposal")]
can_support_proposal(&self, proposal_type: ProposalType) -> bool207     pub fn can_support_proposal(&self, proposal_type: ProposalType) -> bool {
208         #[cfg(feature = "tree_index")]
209         return self.index.count_supporting_proposal(proposal_type) == self.occupied_leaf_count();
210 
211         #[cfg(not(feature = "tree_index"))]
212         self.nodes
213             .non_empty_leaves()
214             .all(|(_, l)| l.capabilities.proposals.contains(&proposal_type))
215     }
216 
217     #[cfg(test)]
218     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
add_leaves<I: IdentityProvider, CP: CipherSuiteProvider>( &mut self, leaf_nodes: Vec<LeafNode>, id_provider: &I, cipher_suite_provider: &CP, ) -> Result<Vec<LeafIndex>, MlsError>219     pub async fn add_leaves<I: IdentityProvider, CP: CipherSuiteProvider>(
220         &mut self,
221         leaf_nodes: Vec<LeafNode>,
222         id_provider: &I,
223         cipher_suite_provider: &CP,
224     ) -> Result<Vec<LeafIndex>, MlsError> {
225         let mut start = LeafIndex(0);
226         let mut added = vec![];
227 
228         for leaf in leaf_nodes.into_iter() {
229             start = self
230                 .add_leaf(leaf, id_provider, &Default::default(), Some(start))
231                 .await?;
232             added.push(start);
233         }
234 
235         self.update_hashes(&added, cipher_suite_provider).await?;
236 
237         Ok(added)
238     }
239 
non_empty_leaves(&self) -> impl Iterator<Item = (LeafIndex, &LeafNode)> + '_240     pub fn non_empty_leaves(&self) -> impl Iterator<Item = (LeafIndex, &LeafNode)> + '_ {
241         self.nodes.non_empty_leaves()
242     }
243 
244     #[cfg(feature = "prior_epoch")]
leaves(&self) -> impl Iterator<Item = Option<&LeafNode>> + '_245     pub fn leaves(&self) -> impl Iterator<Item = Option<&LeafNode>> + '_ {
246         self.nodes.leaves()
247     }
248 
update_node( &mut self, pub_key: crypto::HpkePublicKey, index: NodeIndex, ) -> Result<(), MlsError>249     pub(crate) fn update_node(
250         &mut self,
251         pub_key: crypto::HpkePublicKey,
252         index: NodeIndex,
253     ) -> Result<(), MlsError> {
254         self.nodes
255             .borrow_or_fill_node_as_parent(index, &pub_key)
256             .map(|p| {
257                 p.public_key = pub_key;
258                 p.unmerged_leaves = vec![];
259             })
260     }
261 
262     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
apply_update_path<IP, CP>( &mut self, sender: LeafIndex, update_path: &ValidatedUpdatePath, extensions: &ExtensionList, identity_provider: IP, cipher_suite_provider: &CP, ) -> Result<(), MlsError> where IP: IdentityProvider, CP: CipherSuiteProvider,263     pub(crate) async fn apply_update_path<IP, CP>(
264         &mut self,
265         sender: LeafIndex,
266         update_path: &ValidatedUpdatePath,
267         extensions: &ExtensionList,
268         identity_provider: IP,
269         cipher_suite_provider: &CP,
270     ) -> Result<(), MlsError>
271     where
272         IP: IdentityProvider,
273         CP: CipherSuiteProvider,
274     {
275         // Install the new leaf node
276         let existing_leaf = self.nodes.borrow_as_leaf_mut(sender)?;
277 
278         #[cfg(feature = "tree_index")]
279         let original_leaf_node = existing_leaf.clone();
280 
281         #[cfg(feature = "tree_index")]
282         let original_identity = identity_provider
283             .identity(&original_leaf_node.signing_identity, extensions)
284             .await
285             .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
286 
287         *existing_leaf = update_path.leaf_node.clone();
288 
289         // Update the rest of the nodes on the direct path
290         let path = self.nodes.direct_copath(sender);
291 
292         for (node, pn) in update_path.nodes.iter().zip(path) {
293             node.as_ref()
294                 .map(|n| self.update_node(n.public_key.clone(), pn.path))
295                 .transpose()?;
296         }
297 
298         #[cfg(feature = "tree_index")]
299         self.index.remove(&original_leaf_node, &original_identity);
300 
301         index_insert(
302             #[cfg(feature = "tree_index")]
303             &mut self.index,
304             #[cfg(not(feature = "tree_index"))]
305             &self.nodes,
306             &update_path.leaf_node,
307             sender,
308             &identity_provider,
309             extensions,
310         )
311         .await?;
312 
313         // Verify the parent hash of the new sender leaf node and update the parent hash values
314         // in the local tree
315         self.update_parent_hashes(sender, true, cipher_suite_provider)
316             .await?;
317 
318         Ok(())
319     }
320 
update_unmerged(&mut self, index: LeafIndex) -> Result<(), MlsError>321     fn update_unmerged(&mut self, index: LeafIndex) -> Result<(), MlsError> {
322         // For a given leaf index, find parent nodes and add the leaf to the unmerged leaf
323         self.nodes.direct_copath(index).into_iter().for_each(|i| {
324             if let Ok(p) = self.nodes.borrow_as_parent_mut(i.path) {
325                 p.unmerged_leaves.push(index)
326             }
327         });
328 
329         Ok(())
330     }
331 
332     #[cfg(feature = "by_ref_proposal")]
333     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
batch_edit<I, CP>( &mut self, proposal_bundle: &mut ProposalBundle, extensions: &ExtensionList, id_provider: &I, cipher_suite_provider: &CP, filter: bool, ) -> Result<Vec<LeafIndex>, MlsError> where I: IdentityProvider, CP: CipherSuiteProvider,334     pub async fn batch_edit<I, CP>(
335         &mut self,
336         proposal_bundle: &mut ProposalBundle,
337         extensions: &ExtensionList,
338         id_provider: &I,
339         cipher_suite_provider: &CP,
340         filter: bool,
341     ) -> Result<Vec<LeafIndex>, MlsError>
342     where
343         I: IdentityProvider,
344         CP: CipherSuiteProvider,
345     {
346         // Apply removes (they commute with updates because they don't touch the same leaves)
347         for i in (0..proposal_bundle.remove_proposals().len()).rev() {
348             let index = proposal_bundle.remove_proposals()[i].proposal.to_remove;
349             let res = self.nodes.blank_leaf_node(index);
350 
351             if res.is_ok() {
352                 // This shouldn't fail if `blank_leaf_node` succedded.
353                 self.nodes.blank_direct_path(index)?;
354             }
355 
356             #[cfg(feature = "tree_index")]
357             if let Ok(old_leaf) = &res {
358                 // If this fails, it's not because the proposal is bad.
359                 let identity =
360                     identity(&old_leaf.signing_identity, id_provider, extensions).await?;
361 
362                 self.index.remove(old_leaf, &identity);
363             }
364 
365             if proposal_bundle.remove_proposals()[i].is_by_value() || !filter {
366                 res?;
367             } else if res.is_err() {
368                 proposal_bundle.remove::<RemoveProposal>(i);
369             }
370         }
371 
372         // Remove from the tree old leaves from updates
373         let mut partial_updates = vec![];
374         let senders = proposal_bundle.update_senders.iter().copied();
375 
376         for (i, (p, index)) in proposal_bundle.updates.iter().zip(senders).enumerate() {
377             let new_leaf = p.proposal.leaf_node.clone();
378 
379             match self.nodes.blank_leaf_node(index) {
380                 Ok(old_leaf) => {
381                     #[cfg(feature = "tree_index")]
382                     let old_id =
383                         identity(&old_leaf.signing_identity, id_provider, extensions).await?;
384 
385                     #[cfg(feature = "tree_index")]
386                     self.index.remove(&old_leaf, &old_id);
387 
388                     partial_updates.push((index, old_leaf, new_leaf, i));
389                 }
390                 _ => {
391                     if !filter || !p.is_by_reference() {
392                         return Err(MlsError::UpdatingNonExistingMember);
393                     }
394                 }
395             }
396         }
397 
398         #[cfg(feature = "tree_index")]
399         let index_clone = self.index.clone();
400 
401         let mut removed_leaves = vec![];
402         let mut updated_indices = vec![];
403         let mut bad_indices = vec![];
404 
405         // Apply updates one by one. If there's an update which we can't apply or revert, we revert
406         // all updates.
407         for (index, old_leaf, new_leaf, i) in partial_updates.into_iter() {
408             #[cfg(feature = "tree_index")]
409             let res =
410                 index_insert(&mut self.index, &new_leaf, index, id_provider, extensions).await;
411 
412             #[cfg(not(feature = "tree_index"))]
413             let res = index_insert(&self.nodes, &new_leaf, index, id_provider, extensions).await;
414 
415             let err = res.is_err();
416 
417             if !filter {
418                 res?;
419             }
420 
421             if !err {
422                 self.nodes.insert_leaf(index, new_leaf);
423                 removed_leaves.push(old_leaf);
424                 updated_indices.push(index);
425             } else {
426                 #[cfg(feature = "tree_index")]
427                 let res =
428                     index_insert(&mut self.index, &old_leaf, index, id_provider, extensions).await;
429 
430                 #[cfg(not(feature = "tree_index"))]
431                 let res =
432                     index_insert(&self.nodes, &old_leaf, index, id_provider, extensions).await;
433 
434                 if res.is_ok() {
435                     self.nodes.insert_leaf(index, old_leaf);
436                     bad_indices.push(i);
437                 } else {
438                     // Revert all updates and stop. We're already in the "filter" case, so we don't throw an error.
439                     #[cfg(feature = "tree_index")]
440                     {
441                         self.index = index_clone;
442                     }
443 
444                     removed_leaves
445                         .into_iter()
446                         .zip(updated_indices.iter())
447                         .for_each(|(leaf, index)| self.nodes.insert_leaf(*index, leaf));
448 
449                     updated_indices = vec![];
450                     break;
451                 }
452             }
453         }
454 
455         // If we managed to update something, blank direct paths
456         updated_indices
457             .iter()
458             .try_for_each(|index| self.nodes.blank_direct_path(*index).map(|_| ()))?;
459 
460         // Remove rejected updates from applied proposals
461         if updated_indices.is_empty() {
462             // This takes care of the "revert all" scenario
463             proposal_bundle.updates = vec![];
464         } else {
465             for i in bad_indices.into_iter().rev() {
466                 proposal_bundle.remove::<UpdateProposal>(i);
467                 proposal_bundle.update_senders.remove(i);
468             }
469         }
470 
471         // Apply adds
472         let mut start = LeafIndex(0);
473         let mut added = vec![];
474         let mut bad_indexes = vec![];
475 
476         for i in 0..proposal_bundle.additions.len() {
477             let leaf = proposal_bundle.additions[i]
478                 .proposal
479                 .key_package
480                 .leaf_node
481                 .clone();
482 
483             let res = self
484                 .add_leaf(leaf, id_provider, extensions, Some(start))
485                 .await;
486 
487             if let Ok(index) = res {
488                 start = index;
489                 added.push(start);
490             } else if proposal_bundle.additions[i].is_by_value() || !filter {
491                 res?;
492             } else {
493                 bad_indexes.push(i);
494             }
495         }
496 
497         for i in bad_indexes.into_iter().rev() {
498             proposal_bundle.remove::<AddProposal>(i);
499         }
500 
501         self.nodes.trim();
502 
503         let updated_leaves = proposal_bundle
504             .remove_proposals()
505             .iter()
506             .map(|p| p.proposal.to_remove)
507             .chain(updated_indices)
508             .chain(added.iter().copied())
509             .collect_vec();
510 
511         self.update_hashes(&updated_leaves, cipher_suite_provider)
512             .await?;
513 
514         Ok(added)
515     }
516 
517     #[cfg(not(feature = "by_ref_proposal"))]
518     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
batch_edit_lite<I, CP>( &mut self, proposal_bundle: &ProposalBundle, extensions: &ExtensionList, id_provider: &I, cipher_suite_provider: &CP, ) -> Result<Vec<LeafIndex>, MlsError> where I: IdentityProvider, CP: CipherSuiteProvider,519     pub async fn batch_edit_lite<I, CP>(
520         &mut self,
521         proposal_bundle: &ProposalBundle,
522         extensions: &ExtensionList,
523         id_provider: &I,
524         cipher_suite_provider: &CP,
525     ) -> Result<Vec<LeafIndex>, MlsError>
526     where
527         I: IdentityProvider,
528         CP: CipherSuiteProvider,
529     {
530         // Apply removes
531         for p in &proposal_bundle.removals {
532             let index = p.proposal.to_remove;
533 
534             #[cfg(feature = "tree_index")]
535             {
536                 // If this fails, it's not because the proposal is bad.
537                 let old_leaf = self.nodes.blank_leaf_node(index)?;
538 
539                 let identity =
540                     identity(&old_leaf.signing_identity, id_provider, extensions).await?;
541 
542                 self.index.remove(&old_leaf, &identity);
543             }
544 
545             #[cfg(not(feature = "tree_index"))]
546             self.nodes.blank_leaf_node(index)?;
547 
548             self.nodes.blank_direct_path(index)?;
549         }
550 
551         // Apply adds
552         let mut start = LeafIndex(0);
553         let mut added = vec![];
554 
555         for p in &proposal_bundle.additions {
556             let leaf = p.proposal.key_package.leaf_node.clone();
557             start = self
558                 .add_leaf(leaf, id_provider, extensions, Some(start))
559                 .await?;
560             added.push(start);
561         }
562 
563         self.nodes.trim();
564 
565         let updated_leaves = proposal_bundle
566             .remove_proposals()
567             .iter()
568             .map(|p| p.proposal.to_remove)
569             .chain(added.iter().copied())
570             .collect_vec();
571 
572         self.update_hashes(&updated_leaves, cipher_suite_provider)
573             .await?;
574 
575         Ok(added)
576     }
577 
578     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
add_leaf<I: IdentityProvider>( &mut self, leaf: LeafNode, id_provider: &I, extensions: &ExtensionList, start: Option<LeafIndex>, ) -> Result<LeafIndex, MlsError>579     pub(crate) async fn add_leaf<I: IdentityProvider>(
580         &mut self,
581         leaf: LeafNode,
582         id_provider: &I,
583         extensions: &ExtensionList,
584         start: Option<LeafIndex>,
585     ) -> Result<LeafIndex, MlsError> {
586         let index = self.nodes.next_empty_leaf(start.unwrap_or(LeafIndex(0)));
587 
588         #[cfg(feature = "tree_index")]
589         index_insert(&mut self.index, &leaf, index, id_provider, extensions).await?;
590 
591         #[cfg(not(feature = "tree_index"))]
592         index_insert(&self.nodes, &leaf, index, id_provider, extensions).await?;
593 
594         self.nodes.insert_leaf(index, leaf);
595         self.update_unmerged(index)?;
596 
597         Ok(index)
598     }
599 }
600 
#[cfg(feature = "tree_index")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
/// Asks `provider` for the identity bound to `signing_id`, converting any provider failure
/// into `MlsError::IdentityProviderError`.
async fn identity<I: IdentityProvider>(
    signing_id: &SigningIdentity,
    provider: &I,
    extensions: &ExtensionList,
) -> Result<Vec<u8>, MlsError> {
    match provider.identity(signing_id, extensions).await {
        Ok(id) => Ok(id),
        Err(err) => Err(MlsError::IdentityProviderError(err.into_any_error())),
    }
}
613 
#[cfg(feature = "std")]
impl Display for TreeKemPublic {
    /// Renders the node vector as ASCII art produced by `tree_utils::build_ascii_tree`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let ascii = tree_utils::build_ascii_tree(&self.nodes);
        write!(f, "{ascii}")
    }
}
620 
621 #[cfg(test)]
622 use crate::group::{proposal::Proposal, proposal_filter::ProposalSource, Sender};
623 
#[cfg(test)]
impl TreeKemPublic {
    /// Test helper: replaces the leaf at `leaf_index` with `leaf_node`. It does so by feeding
    /// a single by-value Update proposal, sent by that member, through `batch_edit` with
    /// filtering enabled.
    #[cfg(feature = "by_ref_proposal")]
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn update_leaf<I, CP>(
        &mut self,
        leaf_index: u32,
        leaf_node: LeafNode,
        identity_provider: &I,
        cipher_suite_provider: &CP,
    ) -> Result<(), MlsError>
    where
        I: IdentityProvider,
        CP: CipherSuiteProvider,
    {
        let sender = Sender::Member(leaf_index);
        let update = Proposal::Update(UpdateProposal { leaf_node });

        let mut bundle = ProposalBundle::default();
        bundle.add(update, sender, ProposalSource::ByValue);
        bundle.update_senders = vec![LeafIndex(leaf_index)];

        let extensions = Default::default();

        self.batch_edit(
            &mut bundle,
            &extensions,
            identity_provider,
            cipher_suite_provider,
            true,
        )
        .await
        .map(|_added| ())
    }

    /// Test helper: removes the leaves at `indexes` by running by-value Remove proposals
    /// through the batch editor. Returns each removed index paired with the leaf node it held
    /// before the removal.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn remove_leaves<I, CP>(
        &mut self,
        indexes: Vec<LeafIndex>,
        identity_provider: &I,
        cipher_suite_provider: &CP,
    ) -> Result<Vec<(LeafIndex, LeafNode)>, MlsError>
    where
        I: IdentityProvider,
        CP: CipherSuiteProvider,
    {
        // Keep a copy of the tree so the removed leaves can still be reported afterwards.
        let before_removal = self.clone();

        let mut bundle = ProposalBundle::default();

        for &to_remove in &indexes {
            let removal = Proposal::Remove(RemoveProposal { to_remove });
            bundle.add(removal, Sender::Member(0), ProposalSource::ByValue);
        }

        #[cfg(feature = "by_ref_proposal")]
        self.batch_edit(
            &mut bundle,
            &Default::default(),
            identity_provider,
            cipher_suite_provider,
            true,
        )
        .await?;

        #[cfg(not(feature = "by_ref_proposal"))]
        self.batch_edit_lite(
            &bundle,
            &Default::default(),
            identity_provider,
            cipher_suite_provider,
        )
        .await?;

        let mut removed = Vec::new();

        for removal in bundle.removals.iter() {
            let index = removal.proposal.to_remove;
            let leaf = before_removal.get_leaf_node(index)?.clone();
            removed.push((index, leaf));
        }

        Ok(removed)
    }

    /// Test helper: collects references to all non-empty leaf nodes.
    pub fn get_leaf_nodes(&self) -> Vec<&LeafNode> {
        self.nodes
            .non_empty_leaves()
            .map(|(_index, leaf)| leaf)
            .collect::<Vec<_>>()
    }
}
715 
716 #[cfg(test)]
717 pub(crate) mod test_utils {
718     use crate::crypto::test_utils::TestCryptoProvider;
719     use crate::signer::Signable;
720     use alloc::vec::Vec;
721     use alloc::{format, vec};
722     use mls_rs_core::crypto::CipherSuiteProvider;
723     use mls_rs_core::group::Capabilities;
724     use mls_rs_core::identity::BasicCredential;
725 
726     use crate::identity::test_utils::get_test_signing_identity;
727     use crate::{
728         cipher_suite::CipherSuite,
729         crypto::{HpkeSecretKey, SignatureSecretKey},
730         identity::basic::BasicIdentityProvider,
731         tree_kem::leaf_node::test_utils::get_basic_test_node_sig_key,
732     };
733 
734     use super::leaf_node::{ConfigProperties, LeafNodeSigningContext};
735     use super::node::LeafIndex;
736     use super::Lifetime;
737     use super::{
738         leaf_node::{test_utils::get_basic_test_node, LeafNode},
739         TreeKemPrivate, TreeKemPublic,
740     };
741 
742     #[derive(Debug)]
743     pub(crate) struct TestTree {
744         pub public: TreeKemPublic,
745         pub private: TreeKemPrivate,
746         pub creator_leaf: LeafNode,
747         pub creator_signing_key: SignatureSecretKey,
748         pub creator_hpke_secret: HpkeSecretKey,
749     }
750 
751     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_test_tree(cipher_suite: CipherSuite) -> TestTree752     pub(crate) async fn get_test_tree(cipher_suite: CipherSuite) -> TestTree {
753         let (creator_leaf, creator_hpke_secret, creator_signing_key) =
754             get_basic_test_node_sig_key(cipher_suite, "creator").await;
755 
756         let (test_public, test_private) = TreeKemPublic::derive(
757             creator_leaf.clone(),
758             creator_hpke_secret.clone(),
759             &BasicIdentityProvider,
760             &Default::default(),
761         )
762         .await
763         .unwrap();
764 
765         TestTree {
766             public: test_public,
767             private: test_private,
768             creator_leaf,
769             creator_signing_key,
770             creator_hpke_secret,
771         }
772     }
773 
774     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_test_leaf_nodes(cipher_suite: CipherSuite) -> Vec<LeafNode>775     pub async fn get_test_leaf_nodes(cipher_suite: CipherSuite) -> Vec<LeafNode> {
776         [
777             get_basic_test_node(cipher_suite, "A").await,
778             get_basic_test_node(cipher_suite, "B").await,
779             get_basic_test_node(cipher_suite, "C").await,
780         ]
781         .to_vec()
782     }
783 
784     impl TreeKemPublic {
785         #[cfg(feature = "tree_index")]
equal_internals(&self, other: &TreeKemPublic) -> bool786         pub fn equal_internals(&self, other: &TreeKemPublic) -> bool {
787             self.tree_hashes == other.tree_hashes && self.index == other.index
788         }
789     }
790 
791     #[derive(Debug, Clone)]
792     pub struct TreeWithSigners {
793         pub tree: TreeKemPublic,
794         pub signers: Vec<Option<SignatureSecretKey>>,
795         pub group_id: Vec<u8>,
796     }
797 
798     impl TreeWithSigners {
799         #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
make_full_tree<P: CipherSuiteProvider>( n_leaves: u32, cs: &P, ) -> TreeWithSigners800         pub async fn make_full_tree<P: CipherSuiteProvider>(
801             n_leaves: u32,
802             cs: &P,
803         ) -> TreeWithSigners {
804             let mut tree = TreeWithSigners {
805                 tree: TreeKemPublic::new(),
806                 signers: vec![],
807                 group_id: cs.random_bytes_vec(cs.kdf_extract_size()).unwrap(),
808             };
809 
810             tree.add_member("Alice", cs).await;
811 
812             // A adds B, B adds C, C adds D etc.
813             for i in 1..n_leaves {
814                 tree.add_member(&format!("Alice{i}"), cs).await;
815                 tree.update_committer_path(i - 1, cs).await;
816             }
817 
818             tree
819         }
820 
821         #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
add_member<P: CipherSuiteProvider>(&mut self, name: &str, cs: &P)822         pub async fn add_member<P: CipherSuiteProvider>(&mut self, name: &str, cs: &P) {
823             let (leaf, signer) = make_leaf(name, cs).await;
824             let index = self.tree.nodes.next_empty_leaf(LeafIndex(0));
825             self.tree.nodes.insert_leaf(index, leaf);
826             self.tree.update_unmerged(index).unwrap();
827             let index = *index as usize;
828 
829             match self.signers.len() {
830                 l if l == index => self.signers.push(Some(signer)),
831                 l if l > index => self.signers[index] = Some(signer),
832                 _ => panic!("signer tree size mismatch"),
833             }
834         }
835 
836         #[cfg(feature = "rfc_compliant")]
837         #[cfg_attr(coverage_nightly, coverage(off))]
remove_member(&mut self, member: u32)838         pub fn remove_member(&mut self, member: u32) {
839             self.tree
840                 .nodes
841                 .blank_direct_path(LeafIndex(member))
842                 .unwrap();
843 
844             self.tree.nodes.blank_leaf_node(LeafIndex(member)).unwrap();
845 
846             *self
847                 .signers
848                 .get_mut(member as usize)
849                 .expect("signer tree size mismatch") = None;
850         }
851 
852         #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
update_committer_path<P: CipherSuiteProvider>( &mut self, committer: u32, cs: &P, )853         pub async fn update_committer_path<P: CipherSuiteProvider>(
854             &mut self,
855             committer: u32,
856             cs: &P,
857         ) {
858             let committer = LeafIndex(committer);
859 
860             let path = self.tree.nodes.direct_copath(committer);
861             let filtered = self.tree.nodes.filtered(committer).unwrap();
862 
863             for (n, f) in path.into_iter().zip(filtered) {
864                 if !f {
865                     self.tree
866                         .update_node(cs.kem_generate().await.unwrap().1, n.path)
867                         .unwrap();
868                 }
869             }
870 
871             self.tree.tree_hashes.current = vec![];
872             self.tree.tree_hash(cs).await.unwrap();
873 
874             self.tree
875                 .update_parent_hashes(committer, false, cs)
876                 .await
877                 .unwrap();
878 
879             self.tree.tree_hashes.current = vec![];
880             self.tree.tree_hash(cs).await.unwrap();
881 
882             let context = LeafNodeSigningContext {
883                 group_id: Some(&self.group_id),
884                 leaf_index: Some(*committer),
885             };
886 
887             let signer = self.signers[*committer as usize].as_ref().unwrap();
888 
889             self.tree
890                 .nodes
891                 .borrow_as_leaf_mut(committer)
892                 .unwrap()
893                 .sign(cs, signer, &context)
894                 .await
895                 .unwrap();
896 
897             self.tree.tree_hashes.current = vec![];
898             self.tree.tree_hash(cs).await.unwrap();
899         }
900     }
901 
902     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
make_leaf<P: CipherSuiteProvider>( name: &str, cs: &P, ) -> (LeafNode, SignatureSecretKey)903     pub async fn make_leaf<P: CipherSuiteProvider>(
904         name: &str,
905         cs: &P,
906     ) -> (LeafNode, SignatureSecretKey) {
907         let (signing_identity, signature_key) =
908             get_test_signing_identity(cs.cipher_suite(), name.as_bytes()).await;
909 
910         let capabilities = Capabilities {
911             credentials: vec![BasicCredential::credential_type()],
912             cipher_suites: TestCryptoProvider::all_supported_cipher_suites(),
913             ..Default::default()
914         };
915 
916         let properties = ConfigProperties {
917             capabilities,
918             extensions: Default::default(),
919         };
920 
921         let (leaf, _) = LeafNode::generate(
922             cs,
923             properties,
924             signing_identity,
925             &signature_key,
926             Lifetime::years(1).unwrap(),
927         )
928         .await
929         .unwrap();
930 
931         (leaf, signature_key)
932     }
933 }
934 
#[cfg(test)]
mod tests {
    use crate::client::test_utils::TEST_CIPHER_SUITE;
    use crate::crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider};

    #[cfg(feature = "custom_proposal")]
    use crate::group::proposal::ProposalType;

    use crate::identity::basic::BasicIdentityProvider;
    use crate::tree_kem::leaf_node::LeafNode;
    use crate::tree_kem::node::{LeafIndex, Node, NodeIndex, NodeTypeResolver, Parent};
    use crate::tree_kem::parent_hash::ParentHash;
    use crate::tree_kem::test_utils::{get_test_leaf_nodes, get_test_tree};
    use crate::tree_kem::{MlsError, TreeKemPublic};
    use alloc::vec;
    use alloc::vec::Vec;
    use assert_matches::assert_matches;

    #[cfg(feature = "by_ref_proposal")]
    use alloc::boxed::Box;

    #[cfg(feature = "by_ref_proposal")]
    use crate::{
        client::test_utils::TEST_PROTOCOL_VERSION,
        group::{
            proposal::{Proposal, RemoveProposal, UpdateProposal},
            proposal_filter::{ProposalBundle, ProposalSource},
            proposal_ref::ProposalRef,
            Sender,
        },
        key_package::test_utils::test_key_package,
    };

    // Needed by the `by_ref_proposal` tests and by `custom_proposal_support`.
    #[cfg(any(feature = "by_ref_proposal", feature = "custom_proposal"))]
    use crate::tree_kem::leaf_node::test_utils::get_basic_test_node;

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_derive() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let test_tree = get_test_tree(cipher_suite).await;

            assert_eq!(
                test_tree.public.nodes[0],
                Some(Node::Leaf(test_tree.creator_leaf.clone()))
            );

            assert_eq!(test_tree.private.self_index, LeafIndex(0));

            assert_eq!(
                test_tree.private.secret_keys[0],
                Some(test_tree.creator_hpke_secret)
            );
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_import_export() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let mut test_tree = get_test_tree(TEST_CIPHER_SUITE).await;

        let additional_key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        test_tree
            .public
            .add_leaves(
                additional_key_packages,
                &BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await
            .unwrap();

        let imported = TreeKemPublic::import_node_data(
            test_tree.public.nodes.clone(),
            &BasicIdentityProvider,
            &Default::default(),
        )
        .await
        .unwrap();

        assert_eq!(test_tree.public.nodes, imported.nodes);

        #[cfg(feature = "tree_index")]
        assert_eq!(test_tree.public.index, imported.index);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_add_leaf() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let mut tree = TreeKemPublic::new();

        let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        let res = tree
            .add_leaves(
                leaf_nodes.clone(),
                &BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await
            .unwrap();

        // The leaf count should be equal to the number of packages we added
        assert_eq!(res.len(), leaf_nodes.len());
        assert_eq!(tree.occupied_leaf_count(), leaf_nodes.len() as u32);

        // Each added package should be at the proper index and searchable in the tree
        res.into_iter().zip(leaf_nodes.clone()).for_each(|(r, kp)| {
            assert_eq!(tree.get_leaf_node(r).unwrap(), &kp);
        });

        // Verify the underlying state
        #[cfg(feature = "tree_index")]
        assert_eq!(tree.index.len(), tree.occupied_leaf_count() as usize);

        assert_eq!(tree.nodes.len(), 5);
        assert_eq!(tree.nodes[0], leaf_nodes[0].clone().into());
        assert_eq!(tree.nodes[1], None);
        assert_eq!(tree.nodes[2], leaf_nodes[1].clone().into());
        assert_eq!(tree.nodes[3], None);
        assert_eq!(tree.nodes[4], leaf_nodes[2].clone().into());
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_get_key_packages() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let mut tree = TreeKemPublic::new();

        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        tree.add_leaves(
            key_packages.clone(),
            &BasicIdentityProvider,
            &cipher_suite_provider,
        )
        .await
        .unwrap();

        // The leaves read back from the (initially empty) tree must be exactly the ones that
        // were added, in insertion order.
        let stored = tree.get_leaf_nodes();
        assert_eq!(stored, key_packages.iter().collect::<Vec<_>>());
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_add_leaf_duplicate() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let mut tree = TreeKemPublic::new();

        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        tree.add_leaves(
            key_packages.clone(),
            &BasicIdentityProvider,
            &cipher_suite_provider,
        )
        .await
        .unwrap();

        let res = tree
            .add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
            .await;

        assert_matches!(res, Err(MlsError::DuplicateLeafData(_)));
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_add_leaf_empty_leaf() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        tree.add_leaves(
            [key_packages[0].clone()].to_vec(),
            &BasicIdentityProvider,
            &cipher_suite_provider,
        )
        .await
        .unwrap();

        // Blank the original first leaf so that the next add must reuse it.
        tree.nodes[0] = None;

        tree.add_leaves(
            [key_packages[1].clone()].to_vec(),
            &BasicIdentityProvider,
            &cipher_suite_provider,
        )
        .await
        .unwrap();

        assert_eq!(tree.nodes[0], key_packages[1].clone().into());
        assert_eq!(tree.nodes[1], None);
        assert_eq!(tree.nodes[2], key_packages[0].clone().into());
        assert_eq!(tree.nodes.len(), 3)
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_add_leaf_unmerged() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        tree.add_leaves(
            [key_packages[0].clone(), key_packages[1].clone()].to_vec(),
            &BasicIdentityProvider,
            &cipher_suite_provider,
        )
        .await
        .unwrap();

        tree.nodes[3] = Parent {
            public_key: vec![].into(),
            parent_hash: ParentHash::empty(),
            unmerged_leaves: vec![],
        }
        .into();

        tree.add_leaves(
            [key_packages[2].clone()].to_vec(),
            &BasicIdentityProvider,
            &cipher_suite_provider,
        )
        .await
        .unwrap();

        assert_eq!(
            tree.nodes[3].as_parent().unwrap().unmerged_leaves,
            vec![LeafIndex(3)]
        )
    }

    #[cfg(feature = "by_ref_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_update_leaf() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        // Create a tree
        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;

        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
            .await
            .unwrap();

        // Add in parent nodes so we can detect them clearing after update
        tree.nodes.direct_copath(LeafIndex(0)).iter().for_each(|n| {
            tree.nodes
                .borrow_or_fill_node_as_parent(n.path, &b"pub_key".to_vec().into())
                .unwrap();
        });

        let original_size = tree.occupied_leaf_count();
        let original_leaf_index = LeafIndex(1);

        let updated_leaf = get_basic_test_node(TEST_CIPHER_SUITE, "A").await;

        tree.update_leaf(
            *original_leaf_index,
            updated_leaf.clone(),
            &BasicIdentityProvider,
            &cipher_suite_provider,
        )
        .await
        .unwrap();

        // The tree should not have grown due to an update
        assert_eq!(tree.occupied_leaf_count(), original_size);

        // The cache of tree package indexes should not have grown
        #[cfg(feature = "tree_index")]
        assert_eq!(tree.index.len() as u32, tree.occupied_leaf_count());

        // The key package should be updated in the tree
        assert_eq!(
            tree.get_leaf_node(original_leaf_index).unwrap(),
            &updated_leaf
        );

        // Verify that the direct path has been cleared
        tree.nodes.direct_copath(LeafIndex(0)).iter().for_each(|n| {
            assert!(tree.nodes[n.path as usize].is_none());
        });
    }

    #[cfg(feature = "by_ref_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_update_leaf_not_found() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        // Create a tree
        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;

        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
            .await
            .unwrap();

        let new_key_package = get_basic_test_node(TEST_CIPHER_SUITE, "new").await;

        let res = tree
            .update_leaf(
                128,
                new_key_package,
                &BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await;

        assert_matches!(res, Err(MlsError::UpdatingNonExistingMember));
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_remove_leaf() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        // Create a tree
        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        let indexes = tree
            .add_leaves(
                key_packages.clone(),
                &BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await
            .unwrap();

        let original_leaf_count = tree.occupied_leaf_count();

        // Remove every leaf that was just added
        let expected_result: Vec<(LeafIndex, LeafNode)> =
            indexes.clone().into_iter().zip(key_packages).collect();

        let res = tree
            .remove_leaves(
                indexes.clone(),
                &BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await
            .unwrap();

        // The order may change
        assert!(res.iter().all(|x| expected_result.contains(x)));
        assert!(expected_result.iter().all(|x| res.contains(x)));

        // The leaves should be removed from the tree
        assert_eq!(
            tree.occupied_leaf_count(),
            original_leaf_count - indexes.len() as u32
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_remove_leaf_middle() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        // Create a tree
        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
        let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        let to_remove = tree
            .add_leaves(
                leaf_nodes.clone(),
                &BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await
            .unwrap()[0];

        let original_leaf_count = tree.occupied_leaf_count();

        let res = tree
            .remove_leaves(
                vec![to_remove],
                &BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await
            .unwrap();

        assert_eq!(res, vec![(to_remove, leaf_nodes[0].clone())]);

        // The leaf count should have been reduced by 1
        assert_eq!(tree.occupied_leaf_count(), original_leaf_count - 1);

        // There should be a blank in the tree
        assert_eq!(
            tree.nodes.get(NodeIndex::from(to_remove) as usize).unwrap(),
            &None
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_create_blanks() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        // Create a tree
        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;

        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
            .await
            .unwrap();

        let original_leaf_count = tree.occupied_leaf_count();

        let to_remove = vec![LeafIndex(2)];

        // Remove the leaf from the tree
        tree.remove_leaves(to_remove, &BasicIdentityProvider, &cipher_suite_provider)
            .await
            .unwrap();

        // The occupied leaf count should have been reduced by 1
        assert_eq!(tree.occupied_leaf_count(), original_leaf_count - 1);

        // The total leaf count should remain unchanged
        assert_eq!(tree.total_leaf_count(), original_leaf_count);

        // The location of key_packages[1] should now be blank
        let removed_location = tree
            .nodes
            .get(NodeIndex::from(LeafIndex(2)) as usize)
            .unwrap();

        assert_eq!(removed_location, &None);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_remove_leaf_failure() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        // Create a tree
        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;

        let res = tree
            .remove_leaves(
                vec![LeafIndex(128)],
                &BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await;

        assert_matches!(res, Err(MlsError::InvalidNodeIndex(256)));
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_find_leaf_node() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        // Create a tree
        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;

        let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        tree.add_leaves(
            leaf_nodes.clone(),
            &BasicIdentityProvider,
            &cipher_suite_provider,
        )
        .await
        .unwrap();

        // Find each node
        for (i, leaf_node) in leaf_nodes.iter().enumerate() {
            let expected_index = LeafIndex(i as u32 + 1);
            assert_eq!(tree.find_leaf_node(leaf_node), Some(expected_index));
        }
    }

    // TODO add test for the lite version
    #[cfg(feature = "by_ref_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn batch_edit_works() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
        let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        tree.add_leaves(leaf_nodes, &BasicIdentityProvider, &cipher_suite_provider)
            .await
            .unwrap();

        let mut bundle = ProposalBundle::default();

        let kp = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "D").await;
        let add = Proposal::Add(Box::new(kp.into()));

        bundle.add(add, Sender::Member(0), ProposalSource::ByValue);

        let update = UpdateProposal {
            leaf_node: get_basic_test_node(TEST_CIPHER_SUITE, "A").await,
        };

        let update = Proposal::Update(update);
        let pref = ProposalRef::new_fake(vec![1, 2, 3]);

        bundle.add(update, Sender::Member(1), ProposalSource::ByReference(pref));

        bundle.update_senders = vec![LeafIndex(1)];

        let remove = RemoveProposal {
            to_remove: LeafIndex(2),
        };

        let remove = Proposal::Remove(remove);

        bundle.add(remove, Sender::Member(0), ProposalSource::ByValue);

        tree.batch_edit(
            &mut bundle,
            &Default::default(),
            &BasicIdentityProvider,
            &cipher_suite_provider,
            true,
        )
        .await
        .unwrap();

        assert_eq!(bundle.add_proposals().len(), 1);
        assert_eq!(bundle.remove_proposals().len(), 1);
        assert_eq!(bundle.update_proposals().len(), 1);
    }

    #[cfg(feature = "custom_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn custom_proposal_support() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let mut tree = TreeKemPublic::new();

        let test_proposal_type = ProposalType::from(42);

        let mut leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        leaf_nodes
            .iter_mut()
            .for_each(|n| n.capabilities.proposals.push(test_proposal_type));

        tree.add_leaves(leaf_nodes, &BasicIdentityProvider, &cipher_suite_provider)
            .await
            .unwrap();

        assert!(tree.can_support_proposal(test_proposal_type));
        assert!(!tree.can_support_proposal(ProposalType::from(43)));

        let test_node = get_basic_test_node(TEST_CIPHER_SUITE, "another").await;

        tree.add_leaves(
            vec![test_node],
            &BasicIdentityProvider,
            &cipher_suite_provider,
        )
        .await
        .unwrap();

        assert!(!tree.can_support_proposal(test_proposal_type));
    }
}
1491