1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use super::{Extension, ExtensionError, ExtensionType, MlsExtension};
6 use alloc::vec::Vec;
7 use core::ops::Deref;
8 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
9 
/// A collection of MLS [Extensions](super::Extension).
///
/// Equality between two lists ignores the order of their entries.
///
/// # Warning
///
/// Extension lists require that each type of extension has at most one entry.
/// `set`, `set_from`, `extend`, `From<Vec<Extension>>` and `collect` replace an
/// existing entry of the same type instead of adding a second one, and decoding
/// rejects encoded lists that contain a duplicate extension type.
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Debug, Clone, Default, MlsSize, MlsEncode, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ExtensionList(Vec<Extension>);
24 
25 impl Deref for ExtensionList {
26     type Target = Vec<Extension>;
27 
deref(&self) -> &Self::Target28     fn deref(&self) -> &Self::Target {
29         &self.0
30     }
31 }
32 
33 impl PartialEq for ExtensionList {
eq(&self, other: &Self) -> bool34     fn eq(&self, other: &Self) -> bool {
35         self.len() == other.len()
36             && self
37                 .iter()
38                 .all(|ext| other.get(ext.extension_type).as_ref() == Some(ext))
39     }
40 }
41 
42 impl MlsDecode for ExtensionList {
mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error>43     fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
44         mls_rs_codec::iter::mls_decode_collection(reader, |data| {
45             let mut list = ExtensionList::new();
46 
47             while !data.is_empty() {
48                 let ext = Extension::mls_decode(data)?;
49                 let ext_type = ext.extension_type;
50 
51                 if list.0.iter().any(|e| e.extension_type == ext_type) {
52                     // #[cfg(feature = "std")]
53                     // return Err(mls_rs_codec::Error::Custom(format!(
54                     //    "Extension list has duplicate extension of type {ext_type:?}"
55                     // )));
56 
57                     // #[cfg(not(feature = "std"))]
58                     return Err(mls_rs_codec::Error::Custom(1));
59                 }
60 
61                 list.0.push(ext);
62             }
63 
64             Ok(list)
65         })
66     }
67 }
68 
69 impl From<Vec<Extension>> for ExtensionList {
from(extensions: Vec<Extension>) -> Self70     fn from(extensions: Vec<Extension>) -> Self {
71         extensions.into_iter().collect()
72     }
73 }
74 
75 impl Extend<Extension> for ExtensionList {
extend<T: IntoIterator<Item = Extension>>(&mut self, iter: T)76     fn extend<T: IntoIterator<Item = Extension>>(&mut self, iter: T) {
77         iter.into_iter().for_each(|ext| self.set(ext));
78     }
79 }
80 
81 impl FromIterator<Extension> for ExtensionList {
from_iter<T: IntoIterator<Item = Extension>>(iter: T) -> Self82     fn from_iter<T: IntoIterator<Item = Extension>>(iter: T) -> Self {
83         let mut list = Self::new();
84         list.extend(iter);
85         list
86     }
87 }
88 
89 impl ExtensionList {
90     /// Create a new empty extension list.
new() -> ExtensionList91     pub fn new() -> ExtensionList {
92         Default::default()
93     }
94 
95     /// Retrieve an extension by providing a type that implements the
96     /// [MlsExtension](super::MlsExtension) trait.
97     ///
98     /// Returns an error if the underlying deserialization of the extension
99     /// data fails.
get_as<E: MlsExtension>(&self) -> Result<Option<E>, ExtensionError>100     pub fn get_as<E: MlsExtension>(&self) -> Result<Option<E>, ExtensionError> {
101         self.0
102             .iter()
103             .find(|e| e.extension_type == E::extension_type())
104             .map(E::from_extension)
105             .transpose()
106     }
107 
108     /// Determine if a specific extension exists within the list.
has_extension(&self, ext_id: ExtensionType) -> bool109     pub fn has_extension(&self, ext_id: ExtensionType) -> bool {
110         self.0.iter().any(|e| e.extension_type == ext_id)
111     }
112 
113     /// Set an extension in the list based on a provided type that implements
114     /// the [MlsExtension](super::MlsExtension) trait.
115     ///
116     /// If there is already an entry in the list for the same extension type,
117     /// then the prior value is removed as part of the insertion.
118     ///
119     /// This function will return an error if `ext` fails to serialize
120     /// properly.
set_from<E: MlsExtension>(&mut self, ext: E) -> Result<(), ExtensionError>121     pub fn set_from<E: MlsExtension>(&mut self, ext: E) -> Result<(), ExtensionError> {
122         let ext = ext.into_extension()?;
123         self.set(ext);
124         Ok(())
125     }
126 
127     /// Set an extension in the list based on a raw
128     /// [Extension](super::Extension) value.
129     ///
130     /// If there is already an entry in the list for the same extension type,
131     /// then the prior value is removed as part of the insertion.
set(&mut self, ext: Extension)132     pub fn set(&mut self, ext: Extension) {
133         let mut found = self
134             .0
135             .iter_mut()
136             .find(|e| e.extension_type == ext.extension_type);
137 
138         if let Some(found) = found.take() {
139             *found = ext;
140         } else {
141             self.0.push(ext);
142         }
143     }
144 
145     /// Get a raw [Extension](super::Extension) value based on an
146     /// [ExtensionType](super::ExtensionType).
get(&self, extension_type: ExtensionType) -> Option<Extension>147     pub fn get(&self, extension_type: ExtensionType) -> Option<Extension> {
148         self.0
149             .iter()
150             .find(|e| e.extension_type == extension_type)
151             .cloned()
152     }
153 
154     /// Remove an extension from the list by
155     /// [ExtensionType](super::ExtensionType)
remove(&mut self, ext_type: ExtensionType)156     pub fn remove(&mut self, ext_type: ExtensionType) {
157         self.0.retain(|e| e.extension_type != ext_type)
158     }
159 
160     /// Append another extension list to this one.
161     ///
162     /// If there is already an entry in the list for the same extension type,
163     /// then the existing value is removed.
append(&mut self, others: Self)164     pub fn append(&mut self, others: Self) {
165         self.0.extend(others.0);
166     }
167 }
168 
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;
    use assert_matches::assert_matches;
    use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};

    use crate::extension::{
        list::ExtensionList, Extension, ExtensionType, MlsCodecExtension, MlsExtension,
    };

    #[derive(Debug, Clone, MlsSize, MlsEncode, MlsDecode, PartialEq, Eq)]
    struct TestExtensionA(u32);

    #[derive(Debug, Clone, MlsEncode, MlsDecode, MlsSize, PartialEq, Eq)]
    struct TestExtensionB(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>);

    #[derive(Debug, Clone, MlsEncode, MlsDecode, MlsSize, PartialEq, Eq)]
    struct TestExtensionC(u8);

    impl MlsCodecExtension for TestExtensionA {
        fn extension_type() -> ExtensionType {
            ExtensionType(128)
        }
    }

    impl MlsCodecExtension for TestExtensionB {
        fn extension_type() -> ExtensionType {
            ExtensionType(129)
        }
    }

    impl MlsCodecExtension for TestExtensionC {
        fn extension_type() -> ExtensionType {
            ExtensionType(130)
        }
    }

    #[test]
    fn test_extension_list_get_set_from_get_as() {
        let mut list = ExtensionList::new();

        let ext_a = TestExtensionA(0);
        let ext_b = TestExtensionB(vec![1]);

        // Add the extensions to the list
        list.set_from(ext_a.clone()).unwrap();
        list.set_from(ext_b.clone()).unwrap();

        assert_eq!(list.len(), 2);
        assert_eq!(list.get_as::<TestExtensionA>().unwrap(), Some(ext_a));
        assert_eq!(list.get_as::<TestExtensionB>().unwrap(), Some(ext_b));
    }

    #[test]
    fn test_extension_list_get_set() {
        let mut list = ExtensionList::new();

        let ext_a = Extension::new(ExtensionType(254), vec![0, 1, 2]);
        let ext_b = Extension::new(ExtensionType(255), vec![4, 5, 6]);

        // Add the extensions to the list
        list.set(ext_a.clone());
        list.set(ext_b.clone());

        assert_eq!(list.len(), 2);
        assert_eq!(list.get(ExtensionType(254)), Some(ext_a));
        assert_eq!(list.get(ExtensionType(255)), Some(ext_b));
    }

    #[test]
    fn extension_list_can_overwrite_values() {
        let mut list = ExtensionList::new();

        let ext_1 = TestExtensionA(0);
        let ext_2 = TestExtensionA(1);

        list.set_from(ext_1).unwrap();
        list.set_from(ext_2.clone()).unwrap();

        assert_eq!(list.get_as::<TestExtensionA>().unwrap(), Some(ext_2));
    }

    #[test]
    fn extension_list_will_return_none_for_type_not_stored() {
        let mut list = ExtensionList::new();

        assert!(list.get_as::<TestExtensionA>().unwrap().is_none());

        assert!(list
            .get(<TestExtensionA as MlsCodecExtension>::extension_type())
            .is_none());

        list.set_from(TestExtensionA(1)).unwrap();

        assert!(list.get_as::<TestExtensionB>().unwrap().is_none());

        assert!(list
            .get(<TestExtensionB as MlsCodecExtension>::extension_type())
            .is_none());
    }

    #[test]
    fn test_extension_list_has_ext() {
        let mut list = ExtensionList::new();

        let ext = TestExtensionA(255);

        list.set_from(ext).unwrap();

        assert!(list.has_extension(<TestExtensionA as MlsCodecExtension>::extension_type()));
        assert!(!list.has_extension(42.into()));
    }

    /// Plain vector wrapper used to produce encodings that `ExtensionList`
    /// itself would never emit (e.g. duplicate extension types).
    #[derive(MlsEncode, MlsSize)]
    struct ExtensionsVec(Vec<Extension>);

    #[test]
    fn extension_list_is_serialized_like_a_sequence_of_extensions() {
        let extension_vec = vec![
            Extension::new(ExtensionType(128), vec![0, 1, 2, 3]),
            Extension::new(ExtensionType(129), vec![1, 2, 3, 4]),
        ];

        let extension_list = ExtensionList::from(extension_vec.clone());

        assert_eq!(
            ExtensionsVec(extension_vec).mls_encode_to_vec().unwrap(),
            extension_list.mls_encode_to_vec().unwrap(),
        );
    }

    #[test]
    fn extension_list_round_trips_through_codec() {
        let list = ExtensionList::from(vec![
            TestExtensionA(1).into_extension().unwrap(),
            TestExtensionB(vec![2, 3]).into_extension().unwrap(),
        ]);

        let encoded = list.mls_encode_to_vec().unwrap();
        let decoded = ExtensionList::mls_decode(&mut &*encoded).unwrap();

        assert_eq!(list, decoded);
    }

    #[test]
    fn deserializing_extension_list_fails_on_duplicate_extension() {
        let extensions = ExtensionsVec(vec![
            TestExtensionA(1).into_extension().unwrap(),
            TestExtensionA(2).into_extension().unwrap(),
        ]);

        let serialized_extensions = extensions.mls_encode_to_vec().unwrap();

        assert_matches!(
            ExtensionList::mls_decode(&mut &*serialized_extensions),
            Err(mls_rs_codec::Error::Custom(_))
        );
    }

    #[test]
    fn removing_extension_only_affects_that_type() {
        let mut list = ExtensionList::new();
        list.set_from(TestExtensionA(1)).unwrap();
        list.set_from(TestExtensionB(vec![2])).unwrap();

        list.remove(<TestExtensionA as MlsCodecExtension>::extension_type());

        assert_eq!(list.len(), 1);
        assert!(!list.has_extension(<TestExtensionA as MlsCodecExtension>::extension_type()));
        assert_eq!(
            list.get_as::<TestExtensionB>().unwrap(),
            Some(TestExtensionB(vec![2]))
        );
    }

    #[test]
    fn extension_list_equality_does_not_consider_order() {
        let extensions = [
            TestExtensionA(33).into_extension().unwrap(),
            TestExtensionC(34).into_extension().unwrap(),
        ];

        let a = extensions.iter().cloned().collect::<ExtensionList>();
        let b = extensions.iter().rev().cloned().collect::<ExtensionList>();

        assert_eq!(a, b);
    }

    #[test]
    fn extension_lists_with_different_values_are_not_equal() {
        let a = ExtensionList::from(vec![TestExtensionA(1).into_extension().unwrap()]);
        let b = ExtensionList::from(vec![TestExtensionA(2).into_extension().unwrap()]);

        assert_ne!(a, b);
    }

    #[test]
    fn extension_lists_with_different_lengths_are_not_equal() {
        let shorter = ExtensionList::from(vec![TestExtensionA(1).into_extension().unwrap()]);
        let longer = ExtensionList::from(vec![
            TestExtensionA(1).into_extension().unwrap(),
            TestExtensionC(2).into_extension().unwrap(),
        ]);

        assert_ne!(shorter, longer);
        assert_ne!(longer, shorter);
    }

    #[test]
    fn extending_extension_list_maintains_extension_uniqueness() {
        let mut list = ExtensionList::new();
        list.set_from(TestExtensionA(33)).unwrap();
        list.set_from(TestExtensionC(34)).unwrap();
        list.extend([
            TestExtensionA(35).into_extension().unwrap(),
            TestExtensionB(vec![36]).into_extension().unwrap(),
            TestExtensionA(37).into_extension().unwrap(),
        ]);

        let expected = ExtensionList(vec![
            TestExtensionA(37).into_extension().unwrap(),
            TestExtensionB(vec![36]).into_extension().unwrap(),
            TestExtensionC(34).into_extension().unwrap(),
        ]);

        assert_eq!(list, expected);
    }

    #[test]
    fn extension_list_from_vec_maintains_extension_uniqueness() {
        let list = ExtensionList::from(vec![
            TestExtensionA(33).into_extension().unwrap(),
            TestExtensionC(34).into_extension().unwrap(),
            TestExtensionA(35).into_extension().unwrap(),
        ]);

        let expected = ExtensionList(vec![
            TestExtensionA(35).into_extension().unwrap(),
            TestExtensionC(34).into_extension().unwrap(),
        ]);

        assert_eq!(list, expected);
    }
}
365