1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use super::{Extension, ExtensionError, ExtensionType, MlsExtension}; 6 use alloc::vec::Vec; 7 use core::ops::Deref; 8 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; 9 10 /// A collection of MLS [Extensions](super::Extension). 11 /// 12 /// 13 /// # Warning 14 /// 15 /// Extension lists require that each type of extension has at most one entry. 16 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 17 #[cfg_attr( 18 all(feature = "ffi", not(test)), 19 safer_ffi_gen::ffi_type(clone, opaque) 20 )] 21 #[derive(Debug, Clone, Default, MlsSize, MlsEncode, Eq)] 22 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 23 pub struct ExtensionList(Vec<Extension>); 24 25 impl Deref for ExtensionList { 26 type Target = Vec<Extension>; 27 deref(&self) -> &Self::Target28 fn deref(&self) -> &Self::Target { 29 &self.0 30 } 31 } 32 33 impl PartialEq for ExtensionList { eq(&self, other: &Self) -> bool34 fn eq(&self, other: &Self) -> bool { 35 self.len() == other.len() 36 && self 37 .iter() 38 .all(|ext| other.get(ext.extension_type).as_ref() == Some(ext)) 39 } 40 } 41 42 impl MlsDecode for ExtensionList { mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error>43 fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> { 44 mls_rs_codec::iter::mls_decode_collection(reader, |data| { 45 let mut list = ExtensionList::new(); 46 47 while !data.is_empty() { 48 let ext = Extension::mls_decode(data)?; 49 let ext_type = ext.extension_type; 50 51 if list.0.iter().any(|e| e.extension_type == ext_type) { 52 // #[cfg(feature = "std")] 53 // return Err(mls_rs_codec::Error::Custom(format!( 54 // "Extension list has duplicate extension of type {ext_type:?}" 55 // ))); 56 57 // #[cfg(not(feature = "std"))] 58 return Err(mls_rs_codec::Error::Custom(1)); 59 } 60 61 list.0.push(ext); 62 } 63 64 Ok(list) 65 }) 66 } 67 } 68 69 impl From<Vec<Extension>> for ExtensionList { from(extensions: Vec<Extension>) -> Self70 fn from(extensions: Vec<Extension>) -> Self { 71 extensions.into_iter().collect() 72 } 73 } 74 75 impl Extend<Extension> for ExtensionList { extend<T: IntoIterator<Item = Extension>>(&mut self, iter: T)76 fn extend<T: IntoIterator<Item = Extension>>(&mut self, iter: T) { 77 iter.into_iter().for_each(|ext| self.set(ext)); 78 } 79 } 80 81 impl FromIterator<Extension> for ExtensionList { from_iter<T: IntoIterator<Item = Extension>>(iter: T) -> Self82 fn from_iter<T: IntoIterator<Item = Extension>>(iter: T) -> Self { 83 let mut list = Self::new(); 84 list.extend(iter); 85 list 86 } 87 } 88 89 impl ExtensionList { 90 /// Create a new empty extension list. new() -> ExtensionList91 pub fn new() -> ExtensionList { 92 Default::default() 93 } 94 95 /// Retrieve an extension by providing a type that implements the 96 /// [MlsExtension](super::MlsExtension) trait. 97 /// 98 /// Returns an error if the underlying deserialization of the extension 99 /// data fails. get_as<E: MlsExtension>(&self) -> Result<Option<E>, ExtensionError>100 pub fn get_as<E: MlsExtension>(&self) -> Result<Option<E>, ExtensionError> { 101 self.0 102 .iter() 103 .find(|e| e.extension_type == E::extension_type()) 104 .map(E::from_extension) 105 .transpose() 106 } 107 108 /// Determine if a specific extension exists within the list. has_extension(&self, ext_id: ExtensionType) -> bool109 pub fn has_extension(&self, ext_id: ExtensionType) -> bool { 110 self.0.iter().any(|e| e.extension_type == ext_id) 111 } 112 113 /// Set an extension in the list based on a provided type that implements 114 /// the [MlsExtension](super::MlsExtension) trait. 115 /// 116 /// If there is already an entry in the list for the same extension type, 117 /// then the prior value is removed as part of the insertion. 118 /// 119 /// This function will return an error if `ext` fails to serialize 120 /// properly. set_from<E: MlsExtension>(&mut self, ext: E) -> Result<(), ExtensionError>121 pub fn set_from<E: MlsExtension>(&mut self, ext: E) -> Result<(), ExtensionError> { 122 let ext = ext.into_extension()?; 123 self.set(ext); 124 Ok(()) 125 } 126 127 /// Set an extension in the list based on a raw 128 /// [Extension](super::Extension) value. 129 /// 130 /// If there is already an entry in the list for the same extension type, 131 /// then the prior value is removed as part of the insertion. set(&mut self, ext: Extension)132 pub fn set(&mut self, ext: Extension) { 133 let mut found = self 134 .0 135 .iter_mut() 136 .find(|e| e.extension_type == ext.extension_type); 137 138 if let Some(found) = found.take() { 139 *found = ext; 140 } else { 141 self.0.push(ext); 142 } 143 } 144 145 /// Get a raw [Extension](super::Extension) value based on an 146 /// [ExtensionType](super::ExtensionType). get(&self, extension_type: ExtensionType) -> Option<Extension>147 pub fn get(&self, extension_type: ExtensionType) -> Option<Extension> { 148 self.0 149 .iter() 150 .find(|e| e.extension_type == extension_type) 151 .cloned() 152 } 153 154 /// Remove an extension from the list by 155 /// [ExtensionType](super::ExtensionType) remove(&mut self, ext_type: ExtensionType)156 pub fn remove(&mut self, ext_type: ExtensionType) { 157 self.0.retain(|e| e.extension_type != ext_type) 158 } 159 160 /// Append another extension list to this one. 161 /// 162 /// If there is already an entry in the list for the same extension type, 163 /// then the existing value is removed. append(&mut self, others: Self)164 pub fn append(&mut self, others: Self) { 165 self.0.extend(others.0); 166 } 167 } 168 169 #[cfg(test)] 170 mod tests { 171 use alloc::vec; 172 use alloc::vec::Vec; 173 use assert_matches::assert_matches; 174 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; 175 176 use crate::extension::{ 177 list::ExtensionList, Extension, ExtensionType, MlsCodecExtension, MlsExtension, 178 }; 179 180 #[derive(Debug, Clone, MlsSize, MlsEncode, MlsDecode, PartialEq, Eq)] 181 struct TestExtensionA(u32); 182 183 #[derive(Debug, Clone, MlsEncode, MlsDecode, MlsSize, PartialEq, Eq)] 184 struct TestExtensionB(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>); 185 186 #[derive(Debug, Clone, MlsEncode, MlsDecode, MlsSize, PartialEq, Eq)] 187 struct TestExtensionC(u8); 188 189 impl MlsCodecExtension for TestExtensionA { extension_type() -> ExtensionType190 fn extension_type() -> ExtensionType { 191 ExtensionType(128) 192 } 193 } 194 195 impl MlsCodecExtension for TestExtensionB { extension_type() -> ExtensionType196 fn extension_type() -> ExtensionType { 197 ExtensionType(129) 198 } 199 } 200 201 impl MlsCodecExtension for TestExtensionC { extension_type() -> ExtensionType202 fn extension_type() -> ExtensionType { 203 ExtensionType(130) 204 } 205 } 206 207 #[test] test_extension_list_get_set_from_get_as()208 fn test_extension_list_get_set_from_get_as() { 209 let mut list = ExtensionList::new(); 210 211 let ext_a = TestExtensionA(0); 212 let ext_b = TestExtensionB(vec![1]); 213 214 // Add the extensions to the list 215 list.set_from(ext_a.clone()).unwrap(); 216 list.set_from(ext_b.clone()).unwrap(); 217 218 assert_eq!(list.len(), 2); 219 assert_eq!(list.get_as::<TestExtensionA>().unwrap(), Some(ext_a)); 220 assert_eq!(list.get_as::<TestExtensionB>().unwrap(), Some(ext_b)); 221 } 222 223 #[test] test_extension_list_get_set()224 fn test_extension_list_get_set() { 225 let mut list = ExtensionList::new(); 226 227 let ext_a = Extension::new(ExtensionType(254), vec![0, 1, 2]); 228 let ext_b = Extension::new(ExtensionType(255), vec![4, 5, 6]); 229 230 // Add the extensions to the list 231 list.set(ext_a.clone()); 232 list.set(ext_b.clone()); 233 234 assert_eq!(list.len(), 2); 235 assert_eq!(list.get(ExtensionType(254)), Some(ext_a)); 236 assert_eq!(list.get(ExtensionType(255)), Some(ext_b)); 237 } 238 239 #[test] extension_list_can_overwrite_values()240 fn extension_list_can_overwrite_values() { 241 let mut list = ExtensionList::new(); 242 243 let ext_1 = TestExtensionA(0); 244 let ext_2 = TestExtensionA(1); 245 246 list.set_from(ext_1).unwrap(); 247 list.set_from(ext_2.clone()).unwrap(); 248 249 assert_eq!(list.get_as::<TestExtensionA>().unwrap(), Some(ext_2)); 250 } 251 252 #[test] extension_list_will_return_none_for_type_not_stored()253 fn extension_list_will_return_none_for_type_not_stored() { 254 let mut list = ExtensionList::new(); 255 256 assert!(list.get_as::<TestExtensionA>().unwrap().is_none()); 257 258 assert!(list 259 .get(<TestExtensionA as MlsCodecExtension>::extension_type()) 260 .is_none()); 261 262 list.set_from(TestExtensionA(1)).unwrap(); 263 264 assert!(list.get_as::<TestExtensionB>().unwrap().is_none()); 265 266 assert!(list 267 .get(<TestExtensionB as MlsCodecExtension>::extension_type()) 268 .is_none()); 269 } 270 271 #[test] test_extension_list_has_ext()272 fn test_extension_list_has_ext() { 273 let mut list = ExtensionList::new(); 274 275 let ext = TestExtensionA(255); 276 277 list.set_from(ext).unwrap(); 278 279 assert!(list.has_extension(<TestExtensionA as MlsCodecExtension>::extension_type())); 280 assert!(!list.has_extension(42.into())); 281 } 282 283 #[derive(MlsEncode, MlsSize)] 284 struct ExtensionsVec(Vec<Extension>); 285 286 #[test] extension_list_is_serialized_like_a_sequence_of_extensions()287 fn extension_list_is_serialized_like_a_sequence_of_extensions() { 288 let extension_vec = vec![ 289 Extension::new(ExtensionType(128), vec![0, 1, 2, 3]), 290 Extension::new(ExtensionType(129), vec![1, 2, 3, 4]), 291 ]; 292 293 let extension_list: ExtensionList = ExtensionList::from(extension_vec.clone()); 294 295 assert_eq!( 296 ExtensionsVec(extension_vec).mls_encode_to_vec().unwrap(), 297 extension_list.mls_encode_to_vec().unwrap(), 298 ); 299 } 300 301 #[test] deserializing_extension_list_fails_on_duplicate_extension()302 fn deserializing_extension_list_fails_on_duplicate_extension() { 303 let extensions = ExtensionsVec(vec![ 304 TestExtensionA(1).into_extension().unwrap(), 305 TestExtensionA(2).into_extension().unwrap(), 306 ]); 307 308 let serialized_extensions = extensions.mls_encode_to_vec().unwrap(); 309 310 assert_matches!( 311 ExtensionList::mls_decode(&mut &*serialized_extensions), 312 Err(mls_rs_codec::Error::Custom(_)) 313 ); 314 } 315 316 #[test] extension_list_equality_does_not_consider_order()317 fn extension_list_equality_does_not_consider_order() { 318 let extensions = [ 319 TestExtensionA(33).into_extension().unwrap(), 320 TestExtensionC(34).into_extension().unwrap(), 321 ]; 322 323 let a = extensions.iter().cloned().collect::<ExtensionList>(); 324 let b = extensions.iter().rev().cloned().collect::<ExtensionList>(); 325 326 assert_eq!(a, b); 327 } 328 329 #[test] extending_extension_list_maintains_extension_uniqueness()330 fn extending_extension_list_maintains_extension_uniqueness() { 331 let mut list = ExtensionList::new(); 332 list.set_from(TestExtensionA(33)).unwrap(); 333 list.set_from(TestExtensionC(34)).unwrap(); 334 list.extend([ 335 TestExtensionA(35).into_extension().unwrap(), 336 TestExtensionB(vec![36]).into_extension().unwrap(), 337 TestExtensionA(37).into_extension().unwrap(), 338 ]); 339 340 let expected = ExtensionList(vec![ 341 TestExtensionA(37).into_extension().unwrap(), 342 TestExtensionB(vec![36]).into_extension().unwrap(), 343 TestExtensionC(34).into_extension().unwrap(), 344 ]); 345 346 assert_eq!(list, expected); 347 } 348 349 #[test] extension_list_from_vec_maintains_extension_uniqueness()350 fn extension_list_from_vec_maintains_extension_uniqueness() { 351 let list = ExtensionList::from(vec![ 352 TestExtensionA(33).into_extension().unwrap(), 353 TestExtensionC(34).into_extension().unwrap(), 354 TestExtensionA(35).into_extension().unwrap(), 355 ]); 356 357 let expected = ExtensionList(vec![ 358 TestExtensionA(35).into_extension().unwrap(), 359 TestExtensionC(34).into_extension().unwrap(), 360 ]); 361 362 assert_eq!(list, expected); 363 } 364 } 365