1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5 use alloc::vec::Vec;
6 use core::{fmt::Debug, hash::Hash};
7 use mls_rs_codec::{MlsDecode, MlsEncode};
8
9 use super::node::LeafIndex;
10
11 pub trait TreeIndex:
12 Send + Sync + Eq + Clone + Debug + Default + MlsEncode + MlsDecode + Hash + Ord
13 {
root(&self) -> Self14 fn root(&self) -> Self;
15
left_unchecked(&self) -> Self16 fn left_unchecked(&self) -> Self;
right_unchecked(&self) -> Self17 fn right_unchecked(&self) -> Self;
18
parent_sibling(&self, leaf_count: &Self) -> Option<ParentSibling<Self>>19 fn parent_sibling(&self, leaf_count: &Self) -> Option<ParentSibling<Self>>;
is_leaf(&self) -> bool20 fn is_leaf(&self) -> bool;
is_in_tree(&self, root: &Self) -> bool21 fn is_in_tree(&self, root: &Self) -> bool;
22
23 #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
zero() -> Self24 fn zero() -> Self;
25
26 #[cfg(any(feature = "secret_tree_access", feature = "private_message", test))]
left(&self) -> Option<Self>27 fn left(&self) -> Option<Self> {
28 (!self.is_leaf()).then(|| self.left_unchecked())
29 }
30
31 #[cfg(any(feature = "secret_tree_access", feature = "private_message", test))]
right(&self) -> Option<Self>32 fn right(&self) -> Option<Self> {
33 (!self.is_leaf()).then(|| self.right_unchecked())
34 }
35
direct_copath(&self, leaf_count: &Self) -> Vec<CopathNode<Self>>36 fn direct_copath(&self, leaf_count: &Self) -> Vec<CopathNode<Self>> {
37 let root = leaf_count.root();
38
39 if !self.is_in_tree(&root) {
40 return Vec::new();
41 }
42
43 let mut path = Vec::new();
44 let mut parent = self.clone();
45
46 while let Some(ps) = parent.parent_sibling(leaf_count) {
47 path.push(CopathNode::new(ps.parent.clone(), ps.sibling));
48 parent = ps.parent;
49 }
50
51 path
52 }
53 }
54
/// One step of a direct path: an ancestor on the path (`path`) paired with the
/// corresponding copath node (`copath`).
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct CopathNode<T> {
    /// Node on the direct path.
    pub path: T,
    /// Copath node paired with `path`.
    pub copath: T,
}

impl<T: Clone + PartialEq + Eq + core::fmt::Debug> CopathNode<T> {
    /// Builds a step from its direct-path node and its copath node.
    pub fn new(path: T, copath: T) -> CopathNode<T> {
        Self { path, copath }
    }
}
66
/// The parent of a node together with that node's sibling (the parent's other
/// child).
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct ParentSibling<T> {
    /// Parent of the node.
    pub parent: T,
    /// The node's sibling under `parent`.
    pub sibling: T,
}

impl<T: Clone + PartialEq + Eq + core::fmt::Debug> ParentSibling<T> {
    /// Bundles a parent index with a sibling index.
    pub fn new(parent: T, sibling: T) -> ParentSibling<T> {
        Self { parent, sibling }
    }
}
78
79 macro_rules! impl_tree_stdint {
80 ($t:ty) => {
81 impl TreeIndex for $t {
82 fn root(&self) -> $t {
83 *self - 1
84 }
85
86 /// Panicks if `x` is even in debug, overflows in release.
87 fn left_unchecked(&self) -> Self {
88 *self ^ (0x01 << (self.trailing_ones() - 1))
89 }
90
91 /// Panicks if `x` is even in debug, overflows in release.
92 fn right_unchecked(&self) -> Self {
93 *self ^ (0x03 << (self.trailing_ones() - 1))
94 }
95
96 fn parent_sibling(&self, leaf_count: &Self) -> Option<ParentSibling<Self>> {
97 if self == &leaf_count.root() {
98 return None;
99 }
100
101 let lvl = self.trailing_ones();
102 let p = (self & !(1 << (lvl + 1))) | (1 << lvl);
103
104 let s = if *self < p {
105 p.right_unchecked()
106 } else {
107 p.left_unchecked()
108 };
109
110 Some(ParentSibling::new(p, s))
111 }
112
113 fn is_leaf(&self) -> bool {
114 self & 1 == 0
115 }
116
117 fn is_in_tree(&self, root: &Self) -> bool {
118 *self <= 2 * root
119 }
120
121 #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
122 fn zero() -> Self {
123 0
124 }
125 }
126 };
127 }
128
129 impl_tree_stdint!(u32);
130
131 #[cfg(test)]
132 impl_tree_stdint!(u64);
133
/// Number of levels to climb from leaves `x` and `y` before their ancestors
/// coincide, i.e. the smallest `k` with `x >> k == y >> k`.
///
/// That is exactly the bit length of `x ^ y`, which is `0` when `x == y`.
pub fn leaf_lca_level(x: u32, y: u32) -> u32 {
    u32::BITS - (x ^ y).leading_zeros()
}
147
subtree(x: u32) -> (LeafIndex, LeafIndex)148 pub fn subtree(x: u32) -> (LeafIndex, LeafIndex) {
149 let breadth = 1 << x.trailing_ones();
150 (
151 LeafIndex((x + 1 - breadth) >> 1),
152 LeafIndex(((x + breadth) >> 1) + 1),
153 )
154 }
155
/// Breadth-first, top-down iterator over the node indices of a full binary
/// tree: the root first, then each level from left to right, ending with the
/// leaves.
///
/// Indices use the array layout in which leaves are the even indices. For a
/// tree with 8 leaves the sequence is `7, 3, 11, 1, 5, 9, 13, 0, 2, 4, ..., 14`.
pub struct BfsIterTopDown {
    /// Shift applied to `ctr` on the current level: the level's height plus
    /// one, so it is `1` on the leaf level.
    level: usize,
    /// Low bits (all ones) shared by every node index on the current level.
    mask: usize,
    /// Number of nodes on the current level.
    level_end: usize,
    /// Position of the next node to emit within the current level.
    ctr: usize,
}

impl BfsIterTopDown {
    /// Creates an iterator over every node of a tree with `num_leaves` leaves.
    ///
    /// `num_leaves` is expected to be a power of two. Its trailing-zero count is
    /// taken as the tree depth, so other values do not visit the whole tree.
    /// `num_leaves == 0` yields an empty iterator.
    pub fn new(num_leaves: usize) -> Self {
        if num_leaves == 0 {
            // `0.trailing_zeros()` is `usize::BITS`, which would overflow the
            // shift below: a panic in debug builds, and an effectively
            // unbounded stream of bogus indices in release builds. An empty
            // leaf level makes `next` return `None` immediately instead.
            return Self {
                level: 1,
                mask: 0,
                level_end: 0,
                ctr: 0,
            };
        }

        let depth = num_leaves.trailing_zeros() as usize;
        Self {
            level: depth + 1,
            mask: (1 << depth) - 1,
            level_end: 1,
            ctr: 0,
        }
    }
}

impl Iterator for BfsIterTopDown {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        if self.ctr == self.level_end {
            // The leaf level (`level == 1`) is the last one. Staying here keeps
            // returning `None` on further calls.
            if self.level == 1 {
                return None;
            }
            // Descend one level: twice as many nodes, one fewer shared low bit.
            self.level_end <<= 1;
            self.level -= 1;
            self.ctr = 0;
            self.mask >>= 1;
        }

        let node = (self.ctr << self.level) | self.mask;
        self.ctr += 1;
        Some(node)
    }
}
193
#[cfg(test)]
mod tests {
    use super::*;
    use itertools::Itertools;
    use serde::{Deserialize, Serialize};

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// One tree-math test vector. For a tree with `n_leaves` leaves it holds the
    /// expected node count and root, plus the per-node left/right/parent/sibling
    /// indices (`None` where the relation does not exist).
    #[derive(Serialize, Deserialize)]
    struct TestCase {
        n_leaves: u32,
        n_nodes: u32,
        root: u32,
        left: Vec<Option<u32>>,
        right: Vec<Option<u32>>,
        parent: Vec<Option<u32>>,
        sibling: Vec<Option<u32>>,
    }

    /// Number of nodes in a tree with `n` leaves: `2n - 1`, or `0` when empty.
    pub fn node_width(n: u32) -> u32 {
        match n {
            0 => 0,
            _ => 2 * (n - 1) + 1,
        }
    }

    #[test]
    fn test_bfs_iterator() {
        let order: Vec<usize> = BfsIterTopDown::new(8).collect();
        assert_eq!(order, [7, 3, 11, 1, 5, 9, 13, 0, 2, 4, 6, 8, 10, 12, 14]);
    }

    /// Computes vectors for trees of 1, 2, 4, ..., 128 leaves from the current
    /// implementation. These are handed to `load_test_case_json!`.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_tree_math_test_cases() -> Vec<TestCase> {
        (0..8)
            .map(|log_n_leaves| {
                let n_leaves: u32 = 1 << log_n_leaves;
                let n_nodes = node_width(n_leaves);
                let left = (0..n_nodes).map(|node| node.left()).collect::<Vec<_>>();
                let right = (0..n_nodes).map(|node| node.right()).collect::<Vec<_>>();

                let (parent, sibling) = (0..n_nodes)
                    .map(|node| {
                        node.parent_sibling(&n_leaves)
                            .map(|ps| (ps.parent, ps.sibling))
                            .unzip()
                    })
                    .unzip();

                TestCase {
                    n_leaves,
                    n_nodes,
                    root: n_leaves.root(),
                    left,
                    right,
                    parent,
                    sibling,
                }
            })
            .collect()
    }

    /// Loads the stored tree-math vectors via `load_test_case_json!`, supplying
    /// the generator above.
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(tree_math, generate_tree_math_test_cases())
    }

    #[test]
    fn test_tree_math() {
        for case in load_test_cases() {
            assert_eq!(node_width(case.n_leaves), case.n_nodes);
            assert_eq!(case.n_leaves.root(), case.root);

            for node in 0..case.n_nodes {
                let i = node as usize;
                assert_eq!(node.left(), case.left[i]);
                assert_eq!(node.right(), case.right[i]);

                let relatives = node.parent_sibling(&case.n_leaves);
                assert_eq!(relatives.as_ref().map(|ps| ps.parent), case.parent[i]);
                assert_eq!(relatives.map(|ps| ps.sibling), case.sibling[i]);
            }
        }
    }

    #[test]
    fn test_direct_path() {
        // Expected direct path (bottom-up, excluding the start node) for every
        // node of a 16-leaf tree, indexed by node.
        let expected: Vec<Vec<u32>> = [
            [0x01, 0x03, 0x07, 0x0f].to_vec(),
            [0x03, 0x07, 0x0f].to_vec(),
            [0x01, 0x03, 0x07, 0x0f].to_vec(),
            [0x07, 0x0f].to_vec(),
            [0x05, 0x03, 0x07, 0x0f].to_vec(),
            [0x03, 0x07, 0x0f].to_vec(),
            [0x05, 0x03, 0x07, 0x0f].to_vec(),
            [0x0f].to_vec(),
            [0x09, 0x0b, 0x07, 0x0f].to_vec(),
            [0x0b, 0x07, 0x0f].to_vec(),
            [0x09, 0x0b, 0x07, 0x0f].to_vec(),
            [0x07, 0x0f].to_vec(),
            [0x0d, 0x0b, 0x07, 0x0f].to_vec(),
            [0x0b, 0x07, 0x0f].to_vec(),
            [0x0d, 0x0b, 0x07, 0x0f].to_vec(),
            [].to_vec(),
            [0x11, 0x13, 0x17, 0x0f].to_vec(),
            [0x13, 0x17, 0x0f].to_vec(),
            [0x11, 0x13, 0x17, 0x0f].to_vec(),
            [0x17, 0x0f].to_vec(),
            [0x15, 0x13, 0x17, 0x0f].to_vec(),
            [0x13, 0x17, 0x0f].to_vec(),
            [0x15, 0x13, 0x17, 0x0f].to_vec(),
            [0x0f].to_vec(),
            [0x19, 0x1b, 0x17, 0x0f].to_vec(),
            [0x1b, 0x17, 0x0f].to_vec(),
            [0x19, 0x1b, 0x17, 0x0f].to_vec(),
            [0x17, 0x0f].to_vec(),
            [0x1d, 0x1b, 0x17, 0x0f].to_vec(),
            [0x1b, 0x17, 0x0f].to_vec(),
            [0x1d, 0x1b, 0x17, 0x0f].to_vec(),
        ]
        .to_vec();

        for (node, want) in (0u32..).zip(&expected) {
            let got = node
                .direct_copath(&16)
                .into_iter()
                .map(|step| step.path)
                .collect_vec();

            assert_eq!(want, &got)
        }
    }

    #[test]
    fn test_copath_path() {
        // Expected copath (bottom-up) for every node of a 16-leaf tree,
        // indexed by node.
        let expected: Vec<Vec<u32>> = [
            [0x02, 0x05, 0x0b, 0x17].to_vec(),
            [0x05, 0x0b, 0x17].to_vec(),
            [0x00, 0x05, 0x0b, 0x17].to_vec(),
            [0x0b, 0x17].to_vec(),
            [0x06, 0x01, 0x0b, 0x17].to_vec(),
            [0x01, 0x0b, 0x17].to_vec(),
            [0x04, 0x01, 0x0b, 0x17].to_vec(),
            [0x17].to_vec(),
            [0x0a, 0x0d, 0x03, 0x17].to_vec(),
            [0x0d, 0x03, 0x17].to_vec(),
            [0x08, 0x0d, 0x03, 0x17].to_vec(),
            [0x03, 0x17].to_vec(),
            [0x0e, 0x09, 0x03, 0x17].to_vec(),
            [0x09, 0x03, 0x17].to_vec(),
            [0x0c, 0x09, 0x03, 0x17].to_vec(),
            [].to_vec(),
            [0x12, 0x15, 0x1b, 0x07].to_vec(),
            [0x15, 0x1b, 0x07].to_vec(),
            [0x10, 0x15, 0x1b, 0x07].to_vec(),
            [0x1b, 0x07].to_vec(),
            [0x16, 0x11, 0x1b, 0x07].to_vec(),
            [0x11, 0x1b, 0x07].to_vec(),
            [0x14, 0x11, 0x1b, 0x07].to_vec(),
            [0x07].to_vec(),
            [0x1a, 0x1d, 0x13, 0x07].to_vec(),
            [0x1d, 0x13, 0x07].to_vec(),
            [0x18, 0x1d, 0x13, 0x07].to_vec(),
            [0x13, 0x07].to_vec(),
            [0x1e, 0x19, 0x13, 0x07].to_vec(),
            [0x19, 0x13, 0x07].to_vec(),
            [0x1c, 0x19, 0x13, 0x07].to_vec(),
        ]
        .to_vec();

        for (node, want) in (0u32..).zip(&expected) {
            let got = node
                .direct_copath(&16)
                .into_iter()
                .map(|step| step.copath)
                .collect_vec();

            assert_eq!(want, &got)
        }
    }
}
384