1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec::Vec;
6 use core::{fmt::Debug, hash::Hash};
7 use mls_rs_codec::{MlsDecode, MlsEncode};
8 
9 use super::node::LeafIndex;
10 
11 pub trait TreeIndex:
12     Send + Sync + Eq + Clone + Debug + Default + MlsEncode + MlsDecode + Hash + Ord
13 {
root(&self) -> Self14     fn root(&self) -> Self;
15 
left_unchecked(&self) -> Self16     fn left_unchecked(&self) -> Self;
right_unchecked(&self) -> Self17     fn right_unchecked(&self) -> Self;
18 
parent_sibling(&self, leaf_count: &Self) -> Option<ParentSibling<Self>>19     fn parent_sibling(&self, leaf_count: &Self) -> Option<ParentSibling<Self>>;
is_leaf(&self) -> bool20     fn is_leaf(&self) -> bool;
is_in_tree(&self, root: &Self) -> bool21     fn is_in_tree(&self, root: &Self) -> bool;
22 
23     #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
zero() -> Self24     fn zero() -> Self;
25 
26     #[cfg(any(feature = "secret_tree_access", feature = "private_message", test))]
left(&self) -> Option<Self>27     fn left(&self) -> Option<Self> {
28         (!self.is_leaf()).then(|| self.left_unchecked())
29     }
30 
31     #[cfg(any(feature = "secret_tree_access", feature = "private_message", test))]
right(&self) -> Option<Self>32     fn right(&self) -> Option<Self> {
33         (!self.is_leaf()).then(|| self.right_unchecked())
34     }
35 
direct_copath(&self, leaf_count: &Self) -> Vec<CopathNode<Self>>36     fn direct_copath(&self, leaf_count: &Self) -> Vec<CopathNode<Self>> {
37         let root = leaf_count.root();
38 
39         if !self.is_in_tree(&root) {
40             return Vec::new();
41         }
42 
43         let mut path = Vec::new();
44         let mut parent = self.clone();
45 
46         while let Some(ps) = parent.parent_sibling(leaf_count) {
47             path.push(CopathNode::new(ps.parent.clone(), ps.sibling));
48             parent = ps.parent;
49         }
50 
51         path
52     }
53 }
54 
/// A node on a direct path paired with its copath counterpart.
///
/// As produced by `TreeIndex::direct_copath`, `path` is an ancestor of the
/// starting node and `copath` is the child of `path` that is not on the walk.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct CopathNode<T> {
    /// Node on the direct path.
    pub path: T,
    /// Sibling-side child of `path` (the node on the copath).
    pub copath: T,
}

impl<T> CopathNode<T> {
    /// Creates a new direct-path / copath pair.
    ///
    /// Construction only moves the two values, so no trait bounds on `T` are
    /// required; derived traits (`Clone`, `Debug`, ...) still apply whenever
    /// `T` implements them.
    pub fn new(path: T, copath: T) -> CopathNode<T> {
        CopathNode { path, copath }
    }
}
66 
/// A node's parent together with the parent's other child (the node's sibling).
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct ParentSibling<T> {
    /// Parent of the node the pair was computed for.
    pub parent: T,
    /// The other child of `parent`.
    pub sibling: T,
}

impl<T> ParentSibling<T> {
    /// Creates a new parent / sibling pair.
    ///
    /// Construction only moves the two values, so no trait bounds on `T` are
    /// required; derived traits still apply whenever `T` implements them.
    pub fn new(parent: T, sibling: T) -> ParentSibling<T> {
        ParentSibling { parent, sibling }
    }
}
78 
/// Implements [`TreeIndex`] for a primitive integer type (used here with
/// `u32` and `u64`) using the bit tricks of the array tree layout: leaves sit
/// at even indices, and a node's level is its number of trailing one bits.
macro_rules! impl_tree_stdint {
    ($t:ty) => {
        impl TreeIndex for $t {
            /// Treats `self` as a leaf count and returns `self - 1`.
            ///
            /// This is the root index when the leaf count is a power of two.
            /// NOTE(review): callers appear to rely on leaf counts always being
            /// a power of two — confirm. A count of 0 underflows (debug panic).
            fn root(&self) -> $t {
                *self - 1
            }

            /// Left child of this parent node.
            ///
            /// Panics in debug builds if `self` is even (a leaf), because
            /// `trailing_ones() - 1` underflows; in release builds the
            /// subtraction wraps and the result is meaningless.
            fn left_unchecked(&self) -> Self {
                *self ^ (0x01 << (self.trailing_ones() - 1))
            }

            /// Right child of this parent node.
            ///
            /// Panics in debug builds if `self` is even (a leaf), because
            /// `trailing_ones() - 1` underflows; in release builds the
            /// subtraction wraps and the result is meaningless.
            fn right_unchecked(&self) -> Self {
                *self ^ (0x03 << (self.trailing_ones() - 1))
            }

            /// Parent of `self` and the parent's other child, or `None` when
            /// `self` is the root for `leaf_count`.
            ///
            /// NOTE(review): no `is_in_tree` check is done here; for indices
            /// outside the tree (or non-power-of-two leaf counts) repeated calls
            /// may never reach the root. `direct_copath` guards against this.
            fn parent_sibling(&self, leaf_count: &Self) -> Option<ParentSibling<Self>> {
                if self == &leaf_count.root() {
                    return None;
                }

                // The parent of a level-`lvl` node is found by setting bit `lvl`
                // and clearing bit `lvl + 1`.
                let lvl = self.trailing_ones();
                let p = (self & !(1 << (lvl + 1))) | (1 << lvl);

                // A node smaller than its parent is the left child, so its
                // sibling is the right child, and vice versa.
                let s = if *self < p {
                    p.right_unchecked()
                } else {
                    p.left_unchecked()
                };

                Some(ParentSibling::new(p, s))
            }

            /// Leaves occupy the even indices.
            fn is_leaf(&self) -> bool {
                self & 1 == 0
            }

            /// Valid node indices for a tree rooted at `root` are `0..=2 * root`.
            fn is_in_tree(&self, root: &Self) -> bool {
                *self <= 2 * root
            }

            /// Index of the first (leftmost) node.
            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
            fn zero() -> Self {
                0
            }
        }
    };
}

// `u32` is the index type used by the functions below.
impl_tree_stdint!(u32);

// The `u64` implementation is only compiled for tests.
#[cfg(test)]
impl_tree_stdint!(u64);
133 
/// Returns how many times `x` and `y` must both be halved before they become
/// equal. For leaf indices `x` and `y` this is the level of their lowest common
/// ancestor node; identical inputs yield `0`.
pub fn leaf_lca_level(x: u32, y: u32) -> u32 {
    // `x >> k == y >> k` holds exactly when every set bit of `x ^ y` lies below
    // bit `k`, so the answer is the bit length of `x ^ y`.
    let differing = x ^ y;
    u32::BITS - differing.leading_zeros()
}
147 
subtree(x: u32) -> (LeafIndex, LeafIndex)148 pub fn subtree(x: u32) -> (LeafIndex, LeafIndex) {
149     let breadth = 1 << x.trailing_ones();
150     (
151         LeafIndex((x + 1 - breadth) >> 1),
152         LeafIndex(((x + breadth) >> 1) + 1),
153     )
154 }
155 
/// Iterator over the node indices of a full binary tree in breadth-first
/// order: the root first, then each lower level from left to right, ending
/// with the leaves.
pub struct BfsIterTopDown {
    // One more than the level of the nodes currently being emitted.
    level: usize,
    // Index of the leftmost node on the current level (`2^(level - 1) - 1`).
    mask: usize,
    // Number of nodes on the current level.
    level_end: usize,
    // Position of the next node within the current level.
    ctr: usize,
}

impl BfsIterTopDown {
    /// Creates an iterator for a tree with `num_leaves` leaves.
    ///
    /// Only the number of trailing zero bits of `num_leaves` is used, so it is
    /// expected to be a power of two. With overflow checks enabled,
    /// `num_leaves == 0` panics.
    pub fn new(num_leaves: usize) -> Self {
        let depth = num_leaves.trailing_zeros() as usize;
        let mask = (1 << depth) - 1;

        Self {
            level: depth + 1,
            mask,
            level_end: 1,
            ctr: 0,
        }
    }
}

impl Iterator for BfsIterTopDown {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        let level_finished = self.ctr == self.level_end;

        if level_finished {
            // The leaf level was the last one; stay exhausted from here on.
            if self.level == 1 {
                return None;
            }

            // Move one level down: twice as many nodes, half the offset.
            self.level -= 1;
            self.mask >>= 1;
            self.level_end <<= 1;
            self.ctr = 0;
        }

        let node = (self.ctr << self.level) | self.mask;
        self.ctr += 1;
        Some(node)
    }
}
193 
#[cfg(test)]
mod tests {
    use super::*;
    use itertools::Itertools;
    use serde::{Deserialize, Serialize};

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// One tree-math test vector: the tree size plus, for every node index in
    /// `0..n_nodes`, the expected left child, right child, parent and sibling
    /// (`None` where the relation does not exist).
    #[derive(Serialize, Deserialize)]
    struct TestCase {
        n_leaves: u32,
        n_nodes: u32,
        root: u32,
        left: Vec<Option<u32>>,
        right: Vec<Option<u32>>,
        parent: Vec<Option<u32>>,
        sibling: Vec<Option<u32>>,
    }

    /// Number of nodes in a tree with `n` leaves: `2n - 1`, or `0` when `n == 0`.
    pub fn node_width(n: u32) -> u32 {
        if n == 0 {
            0
        } else {
            2 * (n - 1) + 1
        }
    }

    /// Top-down BFS over an 8-leaf tree visits the root, then each level left to right.
    #[test]
    fn test_bfs_iterator() {
        let expected = [7, 3, 11, 1, 5, 9, 13, 0, 2, 4, 6, 8, 10, 12, 14];
        let bfs = BfsIterTopDown::new(8);
        assert_eq!(bfs.collect::<Vec<_>>(), expected);
    }

    /// Builds test vectors from this module's own implementation for trees of
    /// 1, 2, 4, ..., 128 leaves.
    /// NOTE(review): presumably only used by `load_test_case_json!` when no stored
    /// vector file is available — confirm against the macro definition.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_tree_math_test_cases() -> Vec<TestCase> {
        let mut test_cases = Vec::new();

        for log_n_leaves in 0..8 {
            let n_leaves = 1 << log_n_leaves;
            let n_nodes = node_width(n_leaves);
            let left = (0..n_nodes).map(|x| x.left()).collect::<Vec<_>>();
            let right = (0..n_nodes).map(|x| x.right()).collect::<Vec<_>>();

            // The inner `unzip` splits each optional (parent, sibling) pair; the
            // outer one collects the two columns into separate vectors.
            let (parent, sibling) = (0..n_nodes)
                .map(|x| {
                    x.parent_sibling(&n_leaves)
                        .map(|ps| (ps.parent, ps.sibling))
                        .unzip()
                })
                .unzip();

            test_cases.push(TestCase {
                n_leaves,
                n_nodes,
                root: n_leaves.root(),
                left,
                right,
                parent,
                sibling,
            })
        }

        test_cases
    }

    /// Loads the `tree_math` vectors, falling back to the generator above.
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(tree_math, generate_tree_math_test_cases())
    }

    /// Checks `node_width`, `root`, `left`/`right` and `parent_sibling` for every
    /// node of every test vector.
    #[test]
    fn test_tree_math() {
        let test_cases = load_test_cases();

        for case in test_cases {
            assert_eq!(node_width(case.n_leaves), case.n_nodes);
            assert_eq!(case.n_leaves.root(), case.root);

            for x in 0..case.n_nodes {
                assert_eq!(x.left(), case.left[x as usize]);
                assert_eq!(x.right(), case.right[x as usize]);

                let (p, s) = x
                    .parent_sibling(&case.n_leaves)
                    .map(|ps| (ps.parent, ps.sibling))
                    .unzip();

                assert_eq!(p, case.parent[x as usize]);
                assert_eq!(s, case.sibling[x as usize]);
            }
        }
    }

    /// Direct path (the `path` side of `direct_copath`) of every node of a
    /// 16-leaf tree whose root is 0x0f. Row `i` is for node `i`; the root's
    /// direct path is empty.
    #[test]
    fn test_direct_path() {
        let expected: Vec<Vec<u32>> = [
            [0x01, 0x03, 0x07, 0x0f].to_vec(),
            [0x03, 0x07, 0x0f].to_vec(),
            [0x01, 0x03, 0x07, 0x0f].to_vec(),
            [0x07, 0x0f].to_vec(),
            [0x05, 0x03, 0x07, 0x0f].to_vec(),
            [0x03, 0x07, 0x0f].to_vec(),
            [0x05, 0x03, 0x07, 0x0f].to_vec(),
            [0x0f].to_vec(),
            [0x09, 0x0b, 0x07, 0x0f].to_vec(),
            [0x0b, 0x07, 0x0f].to_vec(),
            [0x09, 0x0b, 0x07, 0x0f].to_vec(),
            [0x07, 0x0f].to_vec(),
            [0x0d, 0x0b, 0x07, 0x0f].to_vec(),
            [0x0b, 0x07, 0x0f].to_vec(),
            [0x0d, 0x0b, 0x07, 0x0f].to_vec(),
            [].to_vec(),
            [0x11, 0x13, 0x17, 0x0f].to_vec(),
            [0x13, 0x17, 0x0f].to_vec(),
            [0x11, 0x13, 0x17, 0x0f].to_vec(),
            [0x17, 0x0f].to_vec(),
            [0x15, 0x13, 0x17, 0x0f].to_vec(),
            [0x13, 0x17, 0x0f].to_vec(),
            [0x15, 0x13, 0x17, 0x0f].to_vec(),
            [0x0f].to_vec(),
            [0x19, 0x1b, 0x17, 0x0f].to_vec(),
            [0x1b, 0x17, 0x0f].to_vec(),
            [0x19, 0x1b, 0x17, 0x0f].to_vec(),
            [0x17, 0x0f].to_vec(),
            [0x1d, 0x1b, 0x17, 0x0f].to_vec(),
            [0x1b, 0x17, 0x0f].to_vec(),
            [0x1d, 0x1b, 0x17, 0x0f].to_vec(),
        ]
        .to_vec();

        for (i, item) in expected.iter().enumerate() {
            let path = (i as u32)
                .direct_copath(&16)
                .into_iter()
                .map(|cp| cp.path)
                .collect_vec();

            assert_eq!(item, &path)
        }
    }

    /// Copath (the `copath` side of `direct_copath`) of every node of the same
    /// 16-leaf tree. Row `i` is for node `i`; the root's copath is empty.
    #[test]
    fn test_copath_path() {
        let expected: Vec<Vec<u32>> = [
            [0x02, 0x05, 0x0b, 0x17].to_vec(),
            [0x05, 0x0b, 0x17].to_vec(),
            [0x00, 0x05, 0x0b, 0x17].to_vec(),
            [0x0b, 0x17].to_vec(),
            [0x06, 0x01, 0x0b, 0x17].to_vec(),
            [0x01, 0x0b, 0x17].to_vec(),
            [0x04, 0x01, 0x0b, 0x17].to_vec(),
            [0x17].to_vec(),
            [0x0a, 0x0d, 0x03, 0x17].to_vec(),
            [0x0d, 0x03, 0x17].to_vec(),
            [0x08, 0x0d, 0x03, 0x17].to_vec(),
            [0x03, 0x17].to_vec(),
            [0x0e, 0x09, 0x03, 0x17].to_vec(),
            [0x09, 0x03, 0x17].to_vec(),
            [0x0c, 0x09, 0x03, 0x17].to_vec(),
            [].to_vec(),
            [0x12, 0x15, 0x1b, 0x07].to_vec(),
            [0x15, 0x1b, 0x07].to_vec(),
            [0x10, 0x15, 0x1b, 0x07].to_vec(),
            [0x1b, 0x07].to_vec(),
            [0x16, 0x11, 0x1b, 0x07].to_vec(),
            [0x11, 0x1b, 0x07].to_vec(),
            [0x14, 0x11, 0x1b, 0x07].to_vec(),
            [0x07].to_vec(),
            [0x1a, 0x1d, 0x13, 0x07].to_vec(),
            [0x1d, 0x13, 0x07].to_vec(),
            [0x18, 0x1d, 0x13, 0x07].to_vec(),
            [0x13, 0x07].to_vec(),
            [0x1e, 0x19, 0x13, 0x07].to_vec(),
            [0x19, 0x13, 0x07].to_vec(),
            [0x1c, 0x19, 0x13, 0x07].to_vec(),
        ]
        .to_vec();

        for (i, item) in expected.iter().enumerate() {
            let copath = (i as u32)
                .direct_copath(&16)
                .into_iter()
                .map(|cp| cp.copath)
                .collect_vec();

            assert_eq!(item, &copath)
        }
    }
}
384