1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use super::leaf_node::LeafNode; 6 use crate::client::MlsError; 7 use crate::crypto::HpkePublicKey; 8 use crate::tree_kem::math as tree_math; 9 use crate::tree_kem::parent_hash::ParentHash; 10 use alloc::vec; 11 use alloc::vec::Vec; 12 use core::hash::Hash; 13 use core::ops::{Deref, DerefMut}; 14 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; 15 use tree_math::{CopathNode, TreeIndex}; 16 17 #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)] 18 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 19 pub(crate) struct Parent { 20 pub public_key: HpkePublicKey, 21 pub parent_hash: ParentHash, 22 pub unmerged_leaves: Vec<LeafIndex>, 23 } 24 25 #[derive( 26 Clone, Copy, Debug, Ord, PartialEq, PartialOrd, Hash, Eq, MlsSize, MlsEncode, MlsDecode, 27 )] 28 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 29 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 30 pub struct LeafIndex(pub(crate) u32); 31 32 impl LeafIndex { new(i: u32) -> Self33 pub fn new(i: u32) -> Self { 34 Self(i) 35 } 36 } 37 38 impl Deref for LeafIndex { 39 type Target = u32; 40 deref(&self) -> &Self::Target41 fn deref(&self) -> &Self::Target { 42 &self.0 43 } 44 } 45 46 impl From<&LeafIndex> for NodeIndex { from(leaf_index: &LeafIndex) -> Self47 fn from(leaf_index: &LeafIndex) -> Self { 48 leaf_index.0 * 2 49 } 50 } 51 52 impl From<LeafIndex> for NodeIndex { from(leaf_index: LeafIndex) -> Self53 fn from(leaf_index: LeafIndex) -> Self { 54 leaf_index.0 * 2 55 } 56 } 57 58 pub(crate) type NodeIndex = u32; 59 60 #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)] 61 #[allow(clippy::large_enum_variant)] 62 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 63 #[repr(u8)] 64 //TODO: Research if this should actually be a Box<Leaf> for memory / performance reasons 65 pub(crate) enum Node { 66 Leaf(LeafNode) = 1u8, 67 Parent(Parent) = 2u8, 68 } 69 70 impl Node { public_key(&self) -> &HpkePublicKey71 pub fn public_key(&self) -> &HpkePublicKey { 72 match self { 73 Node::Parent(p) => &p.public_key, 74 Node::Leaf(l) => &l.public_key, 75 } 76 } 77 } 78 79 impl From<Parent> for Option<Node> { from(p: Parent) -> Self80 fn from(p: Parent) -> Self { 81 Node::from(p).into() 82 } 83 } 84 85 impl From<LeafNode> for Option<Node> { from(l: LeafNode) -> Self86 fn from(l: LeafNode) -> Self { 87 Node::from(l).into() 88 } 89 } 90 91 impl From<Parent> for Node { from(p: Parent) -> Self92 fn from(p: Parent) -> Self { 93 Node::Parent(p) 94 } 95 } 96 97 impl From<LeafNode> for Node { from(l: LeafNode) -> Self98 fn from(l: LeafNode) -> Self { 99 Node::Leaf(l) 100 } 101 } 102 103 pub(crate) trait NodeTypeResolver { as_parent(&self) -> Result<&Parent, MlsError>104 fn as_parent(&self) -> Result<&Parent, MlsError>; as_parent_mut(&mut self) -> Result<&mut Parent, MlsError>105 fn as_parent_mut(&mut self) -> Result<&mut Parent, MlsError>; as_leaf(&self) -> Result<&LeafNode, MlsError>106 fn as_leaf(&self) -> Result<&LeafNode, MlsError>; as_leaf_mut(&mut self) -> Result<&mut LeafNode, MlsError>107 fn as_leaf_mut(&mut self) -> Result<&mut LeafNode, MlsError>; as_non_empty(&self) -> Result<&Node, MlsError>108 fn as_non_empty(&self) -> Result<&Node, MlsError>; 109 } 110 111 impl NodeTypeResolver for Option<Node> { as_parent(&self) -> Result<&Parent, MlsError>112 fn as_parent(&self) -> Result<&Parent, MlsError> { 113 self.as_ref() 114 .and_then(|n| match n { 115 Node::Parent(p) => Some(p), 116 Node::Leaf(_) => None, 117 }) 118 .ok_or(MlsError::ExpectedNode) 119 } 120 as_parent_mut(&mut self) -> Result<&mut Parent, MlsError>121 fn as_parent_mut(&mut self) -> Result<&mut Parent, MlsError> { 122 self.as_mut() 123 .and_then(|n| match n { 124 Node::Parent(p) => Some(p), 125 Node::Leaf(_) => None, 126 }) 127 .ok_or(MlsError::ExpectedNode) 128 } 129 as_leaf(&self) -> Result<&LeafNode, MlsError>130 fn as_leaf(&self) -> Result<&LeafNode, MlsError> { 131 self.as_ref() 132 .and_then(|n| match n { 133 Node::Parent(_) => None, 134 Node::Leaf(l) => Some(l), 135 }) 136 .ok_or(MlsError::ExpectedNode) 137 } 138 as_leaf_mut(&mut self) -> Result<&mut LeafNode, MlsError>139 fn as_leaf_mut(&mut self) -> Result<&mut LeafNode, MlsError> { 140 self.as_mut() 141 .and_then(|n| match n { 142 Node::Parent(_) => None, 143 Node::Leaf(l) => Some(l), 144 }) 145 .ok_or(MlsError::ExpectedNode) 146 } 147 as_non_empty(&self) -> Result<&Node, MlsError>148 fn as_non_empty(&self) -> Result<&Node, MlsError> { 149 self.as_ref().ok_or(MlsError::UnexpectedEmptyNode) 150 } 151 } 152 153 #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode, Default)] 154 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 155 pub(crate) struct NodeVec(Vec<Option<Node>>); 156 157 impl From<Vec<Option<Node>>> for NodeVec { from(x: Vec<Option<Node>>) -> Self158 fn from(x: Vec<Option<Node>>) -> Self { 159 NodeVec(x) 160 } 161 } 162 163 impl Deref for NodeVec { 164 type Target = Vec<Option<Node>>; 165 deref(&self) -> &Self::Target166 fn deref(&self) -> &Self::Target { 167 &self.0 168 } 169 } 170 171 impl DerefMut for NodeVec { deref_mut(&mut self) -> &mut Self::Target172 fn deref_mut(&mut self) -> &mut Self::Target { 173 &mut self.0 174 } 175 } 176 177 impl NodeVec { 178 #[cfg(any(test, all(feature = "custom_proposal", feature = "tree_index")))] occupied_leaf_count(&self) -> u32179 pub fn occupied_leaf_count(&self) -> u32 { 180 self.non_empty_leaves().count() as u32 181 } 182 total_leaf_count(&self) -> u32183 pub fn total_leaf_count(&self) -> u32 { 184 (self.len() as u32 / 2 + 1).next_power_of_two() 185 } 186 187 #[inline] borrow_node(&self, index: NodeIndex) -> Result<&Option<Node>, MlsError>188 pub fn borrow_node(&self, index: NodeIndex) -> Result<&Option<Node>, MlsError> { 189 Ok(self.get(self.validate_index(index)?).unwrap_or(&None)) 190 } 191 validate_index(&self, index: NodeIndex) -> Result<usize, MlsError>192 fn validate_index(&self, index: NodeIndex) -> Result<usize, MlsError> { 193 if (index as usize) >= self.len().next_power_of_two() { 194 Err(MlsError::InvalidNodeIndex(index)) 195 } else { 196 Ok(index as usize) 197 } 198 } 199 200 #[cfg(test)] empty_leaves(&mut self) -> impl Iterator<Item = (LeafIndex, &mut Option<Node>)>201 fn empty_leaves(&mut self) -> impl Iterator<Item = (LeafIndex, &mut Option<Node>)> { 202 self.iter_mut() 203 .step_by(2) 204 .enumerate() 205 .filter(|(_, n)| n.is_none()) 206 .map(|(i, n)| (LeafIndex(i as u32), n)) 207 } 208 non_empty_leaves(&self) -> impl Iterator<Item = (LeafIndex, &LeafNode)> + '_209 pub fn non_empty_leaves(&self) -> impl Iterator<Item = (LeafIndex, &LeafNode)> + '_ { 210 self.leaves() 211 .enumerate() 212 .filter_map(|(i, l)| l.map(|l| (LeafIndex(i as u32), l))) 213 } 214 non_empty_parents(&self) -> impl Iterator<Item = (NodeIndex, &Parent)> + '_215 pub fn non_empty_parents(&self) -> impl Iterator<Item = (NodeIndex, &Parent)> + '_ { 216 self.iter() 217 .enumerate() 218 .skip(1) 219 .step_by(2) 220 .map(|(i, n)| (i as NodeIndex, n)) 221 .filter_map(|(i, n)| n.as_parent().ok().map(|p| (i, p))) 222 } 223 leaves(&self) -> impl Iterator<Item = Option<&LeafNode>> + '_224 pub fn leaves(&self) -> impl Iterator<Item = Option<&LeafNode>> + '_ { 225 self.iter().step_by(2).map(|n| n.as_leaf().ok()) 226 } 227 direct_copath(&self, index: LeafIndex) -> Vec<CopathNode<NodeIndex>>228 pub fn direct_copath(&self, index: LeafIndex) -> Vec<CopathNode<NodeIndex>> { 229 NodeIndex::from(index).direct_copath(&self.total_leaf_count()) 230 } 231 232 // Section 8.4 233 // The filtered direct path of a node is obtained from the node's direct path by removing 234 // all nodes whose child on the nodes's copath has an empty resolution filtered(&self, index: LeafIndex) -> Result<Vec<bool>, MlsError>235 pub fn filtered(&self, index: LeafIndex) -> Result<Vec<bool>, MlsError> { 236 Ok(NodeIndex::from(index) 237 .direct_copath(&self.total_leaf_count()) 238 .into_iter() 239 .map(|cp| self.is_resolution_empty(cp.copath)) 240 .collect()) 241 } 242 243 #[inline] is_blank(&self, index: NodeIndex) -> Result<bool, MlsError>244 pub fn is_blank(&self, index: NodeIndex) -> Result<bool, MlsError> { 245 self.borrow_node(index).map(|n| n.is_none()) 246 } 247 248 #[inline] is_leaf(&self, index: NodeIndex) -> bool249 pub fn is_leaf(&self, index: NodeIndex) -> bool { 250 index % 2 == 0 251 } 252 253 // Blank a previously filled leaf node, and return the existing leaf blank_leaf_node(&mut self, leaf_index: LeafIndex) -> Result<LeafNode, MlsError>254 pub fn blank_leaf_node(&mut self, leaf_index: LeafIndex) -> Result<LeafNode, MlsError> { 255 let node_index = self.validate_index(leaf_index.into())?; 256 257 match self.get_mut(node_index).and_then(Option::take) { 258 Some(Node::Leaf(l)) => Ok(l), 259 _ => Err(MlsError::RemovingNonExistingMember), 260 } 261 } 262 blank_direct_path(&mut self, leaf: LeafIndex) -> Result<(), MlsError>263 pub fn blank_direct_path(&mut self, leaf: LeafIndex) -> Result<(), MlsError> { 264 for i in self.direct_copath(leaf) { 265 if let Some(n) = self.get_mut(i.path as usize) { 266 *n = None 267 } 268 } 269 270 Ok(()) 271 } 272 273 // Remove elements until the last node is non-blank trim(&mut self)274 pub fn trim(&mut self) { 275 while self.last() == Some(&None) { 276 self.pop(); 277 } 278 } 279 borrow_as_parent(&self, node_index: NodeIndex) -> Result<&Parent, MlsError>280 pub fn borrow_as_parent(&self, node_index: NodeIndex) -> Result<&Parent, MlsError> { 281 self.borrow_node(node_index).and_then(|n| n.as_parent()) 282 } 283 borrow_as_parent_mut(&mut self, node_index: NodeIndex) -> Result<&mut Parent, MlsError>284 pub fn borrow_as_parent_mut(&mut self, node_index: NodeIndex) -> Result<&mut Parent, MlsError> { 285 let index = self.validate_index(node_index)?; 286 287 self.get_mut(index) 288 .ok_or(MlsError::InvalidNodeIndex(node_index))? 289 .as_parent_mut() 290 } 291 borrow_as_leaf_mut(&mut self, index: LeafIndex) -> Result<&mut LeafNode, MlsError>292 pub fn borrow_as_leaf_mut(&mut self, index: LeafIndex) -> Result<&mut LeafNode, MlsError> { 293 let node_index = NodeIndex::from(index); 294 let index = self.validate_index(node_index)?; 295 296 self.get_mut(index) 297 .ok_or(MlsError::InvalidNodeIndex(node_index))? 298 .as_leaf_mut() 299 } 300 borrow_as_leaf(&self, index: LeafIndex) -> Result<&LeafNode, MlsError>301 pub fn borrow_as_leaf(&self, index: LeafIndex) -> Result<&LeafNode, MlsError> { 302 let node_index = NodeIndex::from(index); 303 self.borrow_node(node_index).and_then(|n| n.as_leaf()) 304 } 305 borrow_or_fill_node_as_parent( &mut self, node_index: NodeIndex, public_key: &HpkePublicKey, ) -> Result<&mut Parent, MlsError>306 pub fn borrow_or_fill_node_as_parent( 307 &mut self, 308 node_index: NodeIndex, 309 public_key: &HpkePublicKey, 310 ) -> Result<&mut Parent, MlsError> { 311 let index = self.validate_index(node_index)?; 312 313 while self.len() <= index { 314 self.push(None); 315 } 316 317 self.get_mut(index) 318 .ok_or(MlsError::InvalidNodeIndex(node_index)) 319 .and_then(|n| { 320 if n.is_none() { 321 *n = Parent { 322 public_key: public_key.clone(), 323 parent_hash: ParentHash::empty(), 324 unmerged_leaves: vec![], 325 } 326 .into(); 327 } 328 n.as_parent_mut() 329 }) 330 } 331 get_resolution_index(&self, index: NodeIndex) -> Result<Vec<NodeIndex>, MlsError>332 pub fn get_resolution_index(&self, index: NodeIndex) -> Result<Vec<NodeIndex>, MlsError> { 333 let mut indexes = vec![index]; 334 let mut resolution = vec![]; 335 336 while let Some(index) = indexes.pop() { 337 if let Some(Some(node)) = self.get(index as usize) { 338 resolution.push(index); 339 340 if let Node::Parent(p) = node { 341 resolution.extend(p.unmerged_leaves.iter().map(NodeIndex::from)); 342 } 343 } else if !index.is_leaf() { 344 indexes.push(index.right_unchecked()); 345 indexes.push(index.left_unchecked()); 346 } 347 } 348 349 Ok(resolution) 350 } 351 find_in_resolution( &self, index: NodeIndex, to_find: Option<NodeIndex>, ) -> Option<usize>352 pub fn find_in_resolution( 353 &self, 354 index: NodeIndex, 355 to_find: Option<NodeIndex>, 356 ) -> Option<usize> { 357 let mut indexes = vec![index]; 358 let mut resolution_len = 0; 359 360 while let Some(index) = indexes.pop() { 361 if let Some(Some(node)) = self.get(index as usize) { 362 if Some(index) == to_find || to_find.is_none() { 363 return Some(resolution_len); 364 } 365 366 resolution_len += 1; 367 368 if let Node::Parent(p) = node { 369 indexes.extend(p.unmerged_leaves.iter().map(NodeIndex::from)); 370 } 371 } else if !index.is_leaf() { 372 indexes.push(index.right_unchecked()); 373 indexes.push(index.left_unchecked()); 374 } 375 } 376 377 None 378 } 379 is_resolution_empty(&self, index: NodeIndex) -> bool380 pub fn is_resolution_empty(&self, index: NodeIndex) -> bool { 381 self.find_in_resolution(index, None).is_none() 382 } 383 next_empty_leaf(&self, start: LeafIndex) -> LeafIndex384 pub(crate) fn next_empty_leaf(&self, start: LeafIndex) -> LeafIndex { 385 let mut n = NodeIndex::from(start) as usize; 386 387 while n < self.len() { 388 if self.0[n].is_none() { 389 return LeafIndex((n as u32) >> 1); 390 } 391 392 n += 2; 393 } 394 395 LeafIndex((self.len() as u32 + 1) >> 1) 396 } 397 398 /// If `index` fits in the current tree, inserts `leaf` at `index`. Else, inserts `leaf` as the 399 /// last leaf insert_leaf(&mut self, index: LeafIndex, leaf: LeafNode)400 pub fn insert_leaf(&mut self, index: LeafIndex, leaf: LeafNode) { 401 let node_index = (*index as usize) << 1; 402 403 if node_index > self.len() { 404 self.push(None); 405 self.push(None); 406 } else if self.is_empty() { 407 self.push(None); 408 } 409 410 self.0[node_index] = Some(leaf.into()); 411 } 412 } 413 414 #[cfg(test)] 415 pub(crate) mod test_utils { 416 use super::*; 417 use crate::{ 418 client::test_utils::TEST_CIPHER_SUITE, tree_kem::leaf_node::test_utils::get_basic_test_node, 419 }; 420 421 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] get_test_node_vec() -> NodeVec422 pub(crate) async fn get_test_node_vec() -> NodeVec { 423 let mut nodes = vec![None; 7]; 424 425 nodes[0] = get_basic_test_node(TEST_CIPHER_SUITE, "A").await.into(); 426 nodes[4] = get_basic_test_node(TEST_CIPHER_SUITE, "C").await.into(); 427 428 nodes[5] = Parent { 429 public_key: b"CD".to_vec().into(), 430 parent_hash: ParentHash::empty(), 431 unmerged_leaves: vec![LeafIndex(2)], 432 } 433 .into(); 434 435 nodes[6] = get_basic_test_node(TEST_CIPHER_SUITE, "D").await.into(); 436 437 NodeVec::from(nodes) 438 } 439 } 440 441 #[cfg(test)] 442 mod tests { 443 use super::*; 444 use crate::{ 445 client::test_utils::TEST_CIPHER_SUITE, 446 tree_kem::{ 447 leaf_node::test_utils::get_basic_test_node, node::test_utils::get_test_node_vec, 448 }, 449 }; 450 451 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] node_key_getters()452 async fn node_key_getters() { 453 let test_node_parent: Node = Parent { 454 public_key: b"pub".to_vec().into(), 455 parent_hash: ParentHash::empty(), 456 unmerged_leaves: vec![], 457 } 458 .into(); 459 460 let test_leaf = get_basic_test_node(TEST_CIPHER_SUITE, "B").await; 461 let test_node_leaf: Node = test_leaf.clone().into(); 462 463 assert_eq!(test_node_parent.public_key().as_ref(), b"pub"); 464 assert_eq!(test_node_leaf.public_key(), &test_leaf.public_key); 465 } 466 467 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_empty_leaves()468 async fn test_empty_leaves() { 469 let mut test_vec = get_test_node_vec().await; 470 let mut test_vec_clone = get_test_node_vec().await; 471 let empty_leaves: Vec<(LeafIndex, &mut Option<Node>)> = test_vec.empty_leaves().collect(); 472 assert_eq!( 473 [(LeafIndex(1), &mut test_vec_clone[2])].as_ref(), 474 empty_leaves.as_slice() 475 ); 476 } 477 478 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_direct_path()479 async fn test_direct_path() { 480 let test_vec = get_test_node_vec().await; 481 // Tree math is already tested in that module, just ensure equality 482 let expected = 0.direct_copath(&4); 483 let actual = test_vec.direct_copath(LeafIndex(0)); 484 assert_eq!(actual, expected); 485 } 486 487 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_filtered_direct_path_co_path()488 async fn test_filtered_direct_path_co_path() { 489 let test_vec = get_test_node_vec().await; 490 let expected = [true, false]; 491 let actual = test_vec.filtered(LeafIndex(0)).unwrap(); 492 assert_eq!(actual, expected); 493 } 494 495 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_get_parent_node()496 async fn test_get_parent_node() { 497 let mut test_vec = get_test_node_vec().await; 498 499 // If the node is a leaf it should fail 500 assert!(test_vec.borrow_as_parent_mut(0).is_err()); 501 502 // If the node index is out of range it should fail 503 assert!(test_vec 504 .borrow_as_parent_mut(test_vec.len() as u32) 505 .is_err()); 506 507 // Otherwise it should succeed 508 let mut expected = Parent { 509 public_key: b"CD".to_vec().into(), 510 parent_hash: ParentHash::empty(), 511 unmerged_leaves: vec![LeafIndex(2)], 512 }; 513 514 assert_eq!(test_vec.borrow_as_parent_mut(5).unwrap(), &mut expected); 515 } 516 517 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_get_resolution()518 async fn test_get_resolution() { 519 let test_vec = get_test_node_vec().await; 520 521 let resolution_node_5 = test_vec.get_resolution_index(5).unwrap(); 522 let resolution_node_2 = test_vec.get_resolution_index(2).unwrap(); 523 let resolution_node_3 = test_vec.get_resolution_index(3).unwrap(); 524 525 assert_eq!(&resolution_node_5, &[5, 4]); 526 assert!(resolution_node_2.is_empty()); 527 assert_eq!(&resolution_node_3, &[0, 5, 4]); 528 } 529 530 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_get_or_fill_existing()531 async fn test_get_or_fill_existing() { 532 let mut test_vec = get_test_node_vec().await; 533 let mut test_vec2 = test_vec.clone(); 534 535 let expected = test_vec[5].as_parent_mut().unwrap(); 536 let actual = test_vec2 537 .borrow_or_fill_node_as_parent(5, &Vec::new().into()) 538 .unwrap(); 539 540 assert_eq!(actual, expected); 541 } 542 543 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_get_or_fill_empty()544 async fn test_get_or_fill_empty() { 545 let mut test_vec = get_test_node_vec().await; 546 547 let mut expected = Parent { 548 public_key: vec![0u8; 4].into(), 549 parent_hash: ParentHash::empty(), 550 unmerged_leaves: vec![], 551 }; 552 553 let actual = test_vec 554 .borrow_or_fill_node_as_parent(1, &vec![0u8; 4].into()) 555 .unwrap(); 556 557 assert_eq!(actual, &mut expected); 558 } 559 560 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_leaf_count()561 async fn test_leaf_count() { 562 let test_vec = get_test_node_vec().await; 563 assert_eq!(test_vec.len(), 7); 564 assert_eq!(test_vec.occupied_leaf_count(), 3); 565 assert_eq!( 566 test_vec.non_empty_leaves().count(), 567 test_vec.occupied_leaf_count() as usize 568 ); 569 } 570 571 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_total_leaf_count()572 async fn test_total_leaf_count() { 573 let test_vec = get_test_node_vec().await; 574 assert_eq!(test_vec.occupied_leaf_count(), 3); 575 assert_eq!(test_vec.total_leaf_count(), 4); 576 } 577 } 578