1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use mls_rs_core::{crypto::CipherSuiteProvider, extension::ExtensionList, group::Capabilities};
6 
7 use crate::{
8     client::MlsError,
9     group::{GroupInfo, NewMemberInfo},
10     key_package::KeyPackage,
11     tree_kem::leaf_node::LeafNode,
12 };
13 
14 impl LeafNode {
ungreased_capabilities(&self) -> Capabilities15     pub fn ungreased_capabilities(&self) -> Capabilities {
16         let mut capabilitites = self.capabilities.clone();
17         grease_functions::ungrease(&mut capabilitites.cipher_suites);
18         grease_functions::ungrease(&mut capabilitites.extensions);
19         grease_functions::ungrease(&mut capabilitites.proposals);
20         grease_functions::ungrease(&mut capabilitites.credentials);
21         capabilitites
22     }
23 
ungreased_extensions(&self) -> ExtensionList24     pub fn ungreased_extensions(&self) -> ExtensionList {
25         let mut extensions = self.extensions.clone();
26         grease_functions::ungrease_extensions(&mut extensions);
27         extensions
28     }
29 
grease<P: CipherSuiteProvider>(&mut self, cs: &P) -> Result<(), MlsError>30     pub fn grease<P: CipherSuiteProvider>(&mut self, cs: &P) -> Result<(), MlsError> {
31         grease_functions::grease(&mut self.capabilities.cipher_suites, cs)?;
32         grease_functions::grease(&mut self.capabilities.proposals, cs)?;
33         grease_functions::grease(&mut self.capabilities.credentials, cs)?;
34 
35         let mut new_extensions = grease_functions::grease_extensions(&mut self.extensions, cs)?;
36         self.capabilities.extensions.append(&mut new_extensions);
37 
38         Ok(())
39     }
40 }
41 
42 impl KeyPackage {
grease<P: CipherSuiteProvider>(&mut self, cs: &P) -> Result<(), MlsError>43     pub fn grease<P: CipherSuiteProvider>(&mut self, cs: &P) -> Result<(), MlsError> {
44         grease_functions::grease_extensions(&mut self.extensions, cs).map(|_| ())
45     }
46 
ungreased_extensions(&self) -> ExtensionList47     pub fn ungreased_extensions(&self) -> ExtensionList {
48         let mut extensions = self.extensions.clone();
49         grease_functions::ungrease_extensions(&mut extensions);
50         extensions
51     }
52 }
53 
54 impl GroupInfo {
grease<P: CipherSuiteProvider>(&mut self, cs: &P) -> Result<(), MlsError>55     pub fn grease<P: CipherSuiteProvider>(&mut self, cs: &P) -> Result<(), MlsError> {
56         grease_functions::grease_extensions(&mut self.extensions, cs).map(|_| ())
57     }
58 }
59 
60 impl NewMemberInfo {
ungrease(&mut self)61     pub fn ungrease(&mut self) {
62         grease_functions::ungrease_extensions(&mut self.group_info_extensions)
63     }
64 }
65 
/// GREASE helpers used when the `grease` feature is enabled.
///
/// NOTE(review): this module uses `Vec` and `vec!` without an `alloc` import,
/// unlike its `not(feature = "grease")` sibling. Confirm that `grease` is only
/// ever built together with `std`.
#[cfg(feature = "grease")]
mod grease_functions {
    use core::ops::Deref;

    use mls_rs_core::{
        crypto::CipherSuiteProvider,
        error::IntoAnyError,
        extension::{Extension, ExtensionList, ExtensionType},
    };

    use super::MlsError;

    /// The GREASE code points: every `u16` of the form `0xNANA` for `N` in
    /// `0x0..=0xE`.
    pub const GREASE_VALUES: &[u16] = &[
        0x0A0A, 0x1A1A, 0x2A2A, 0x3A3A, 0x4A4A, 0x5A5A, 0x6A6A, 0x7A7A, 0x8A8A, 0x9A9A, 0xAAAA,
        0xBABA, 0xCACA, 0xDADA, 0xEAEA,
    ];

    /// Pushes one randomly selected GREASE value, converted into `T`, onto `array`.
    ///
    /// # Errors
    ///
    /// Returns [`MlsError::CryptoProviderError`] if `cs` fails to produce
    /// randomness; `array` is left unchanged in that case.
    pub fn grease<T, P>(array: &mut Vec<T>, cs: &P) -> Result<(), MlsError>
    where
        T: From<u16>,
        P: CipherSuiteProvider,
    {
        let value = random_grease_value(cs)?;
        array.push(value.into());
        Ok(())
    }

    /// Sets one GREASE extension with an empty payload on `extensions` and
    /// returns the extension type that was used, so callers can advertise it.
    ///
    /// # Errors
    ///
    /// Returns [`MlsError::CryptoProviderError`] if `cs` fails to produce
    /// randomness; `extensions` is left unchanged in that case.
    pub fn grease_extensions<P>(
        extensions: &mut ExtensionList,
        cs: &P,
    ) -> Result<Vec<ExtensionType>, MlsError>
    where
        P: CipherSuiteProvider,
    {
        let grease_value = random_grease_value(cs)?;
        extensions.set(Extension::new(grease_value.into(), vec![]));
        Ok(vec![grease_value.into()])
    }

    /// Picks one entry of [`GREASE_VALUES`] using a single random byte from `cs`.
    ///
    /// The byte is written into a stack buffer rather than requested as a
    /// `Vec`, so no heap allocation is needed and there is no index into a
    /// provider-returned vector that could panic if it came back empty.
    ///
    /// The byte is reduced modulo the table length. Since 256 is not a
    /// multiple of 15, the first entry is very slightly more likely than the
    /// others.
    fn random_grease_value<P: CipherSuiteProvider>(cs: &P) -> Result<u16, MlsError> {
        let mut byte = [0u8; 1];

        cs.random_bytes(&mut byte)
            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;

        // The modulo keeps the index strictly below the (non-zero) table length.
        let index = usize::from(byte[0]) % GREASE_VALUES.len();

        Ok(GREASE_VALUES[index])
    }

    /// Removes every element of `array` that dereferences to a GREASE value,
    /// preserving the order of the remaining elements.
    pub fn ungrease<T>(array: &mut Vec<T>)
    where
        T: Deref<Target = u16>,
    {
        array.retain(|item| !GREASE_VALUES.contains(&**item));
    }

    /// Removes the extension for each GREASE extension type from `extensions`.
    pub fn ungrease_extensions(extensions: &mut ExtensionList) {
        for &value in GREASE_VALUES {
            extensions.remove(value.into());
        }
    }
}
118 
119 #[cfg(not(feature = "grease"))]
120 mod grease_functions {
121     use core::ops::Deref;
122 
123     use alloc::vec::Vec;
124 
125     use mls_rs_core::{
126         crypto::CipherSuiteProvider,
127         extension::{ExtensionList, ExtensionType},
128     };
129 
130     use super::MlsError;
131 
grease<T: From<u16>, P: CipherSuiteProvider>( _array: &mut [T], _cs: &P, ) -> Result<(), MlsError>132     pub fn grease<T: From<u16>, P: CipherSuiteProvider>(
133         _array: &mut [T],
134         _cs: &P,
135     ) -> Result<(), MlsError> {
136         Ok(())
137     }
138 
grease_extensions<P: CipherSuiteProvider>( _extensions: &mut ExtensionList, _cs: &P, ) -> Result<Vec<ExtensionType>, MlsError>139     pub fn grease_extensions<P: CipherSuiteProvider>(
140         _extensions: &mut ExtensionList,
141         _cs: &P,
142     ) -> Result<Vec<ExtensionType>, MlsError> {
143         Ok(Vec::new())
144     }
145 
ungrease<T: Deref<Target = u16>>(_array: &mut [T])146     pub fn ungrease<T: Deref<Target = u16>>(_array: &mut [T]) {}
147 
ungrease_extensions(_extensions: &mut ExtensionList)148     pub fn ungrease_extensions(_extensions: &mut ExtensionList) {}
149 }
150 
#[cfg(all(test, feature = "grease"))]
mod tests {
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use std::ops::Deref;

    use mls_rs_core::extension::ExtensionList;

    use crate::{
        client::test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        group::test_utils::test_group,
    };

    use super::grease_functions::GREASE_VALUES;

    /// Returns `true` when at least one element of `list` dereferences to a
    /// GREASE value.
    fn is_greased<T: Deref<Target = u16>>(list: &[T]) -> bool {
        list.iter()
            .map(|item| **item)
            .any(|value| GREASE_VALUES.contains(&value))
    }

    /// Returns `true` when at least one extension in `extensions` has a
    /// GREASE extension type.
    fn is_ext_greased(extensions: &ExtensionList) -> bool {
        extensions
            .iter()
            .map(|ext| ext.extension_type())
            .any(|ext_type| GREASE_VALUES.contains(&*ext_type))
    }

    /// A freshly generated key package carries GREASE in its own extensions,
    /// in its leaf node extensions and in every greased capability list, but
    /// not in the protocol versions.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn key_package_is_greased() {
        let message = test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice")
            .await
            .1;

        let key_pkg = message.into_key_package().unwrap();
        let leaf = &key_pkg.leaf_node;
        let caps = &leaf.capabilities;

        assert!(is_ext_greased(&key_pkg.extensions));
        assert!(is_ext_greased(&leaf.extensions));
        assert!(is_greased(&caps.cipher_suites));
        assert!(is_greased(&caps.extensions));
        assert!(is_greased(&caps.proposals));
        assert!(is_greased(&caps.credentials));

        assert!(!is_greased(&caps.protocol_versions));
    }

    /// An exported group info message carries a GREASE extension.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn group_info_is_greased() {
        let group_info = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE)
            .await
            .group
            .group_info_message_allowing_ext_commit(false)
            .await
            .unwrap()
            .into_group_info()
            .unwrap();

        assert!(is_ext_greased(&group_info.extensions));
    }

    /// Roster members exposed through the public API show no GREASE values in
    /// their extensions or in any capability list.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn public_api_is_not_greased() {
        let member = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE)
            .await
            .group
            .roster()
            .member_with_index(0)
            .unwrap();

        assert!(!is_ext_greased(member.extensions()));

        let caps = member.capabilities();
        assert!(!is_greased(caps.protocol_versions()));
        assert!(!is_greased(caps.cipher_suites()));
        assert!(!is_greased(caps.extensions()));
        assert!(!is_greased(caps.proposals()));
        assert!(!is_greased(caps.credentials()));
    }
}
228