1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use crate::cipher_suite::CipherSuite;
6 use crate::client::MlsError;
7 use crate::crypto::HpkePublicKey;
8 use crate::hash_reference::HashReference;
9 use crate::identity::SigningIdentity;
10 use crate::protocol_version::ProtocolVersion;
11 use crate::signer::Signable;
12 use crate::tree_kem::leaf_node::{LeafNode, LeafNodeSource};
13 use crate::CipherSuiteProvider;
14 use alloc::vec::Vec;
15 use core::{
16     fmt::{self, Debug},
17     ops::Deref,
18 };
19 use mls_rs_codec::MlsDecode;
20 use mls_rs_codec::MlsEncode;
21 use mls_rs_codec::MlsSize;
22 use mls_rs_core::extension::ExtensionList;
23 
24 mod validator;
25 pub(crate) use validator::*;
26 
27 pub(crate) mod generator;
28 pub(crate) use generator::*;
29 
/// An MLS key package.
///
/// Every field except `signature` makes up the to-be-signed content (signed under the
/// `"KeyPackageTBS"` label). `signature` holds the signature over that content.
#[non_exhaustive]
#[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct KeyPackage {
    /// Protocol version this key package was created for.
    pub version: ProtocolVersion,
    /// Cipher suite this key package uses.
    pub cipher_suite: CipherSuite,
    /// HPKE public init key carried by this key package.
    pub hpke_init_key: HpkePublicKey,
    /// Leaf node of the key package owner (crate-private). Its signing identity is
    /// exposed through `signing_identity()`.
    pub(crate) leaf_node: LeafNode,
    /// Key package extensions.
    pub extensions: ExtensionList,
    /// Raw signature bytes. They are MLS-encoded with `mls_rs_codec::byte_vec` and
    /// serde-encoded with `mls_rs_core::vec_serde`.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    pub signature: Vec<u8>,
}
48 
49 impl Debug for KeyPackage {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result50     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51         f.debug_struct("KeyPackage")
52             .field("version", &self.version)
53             .field("cipher_suite", &self.cipher_suite)
54             .field("hpke_init_key", &self.hpke_init_key)
55             .field("leaf_node", &self.leaf_node)
56             .field("extensions", &self.extensions)
57             .field(
58                 "signature",
59                 &mls_rs_core::debug::pretty_bytes(&self.signature),
60             )
61             .finish()
62     }
63 }
64 
/// Reference to a [`KeyPackage`], as computed by `KeyPackage::to_reference`.
///
/// Thin wrapper around a [`HashReference`]. The derived ordering, equality and hashing
/// make it usable as a map or set key.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
pub struct KeyPackageRef(HashReference);
72 
/// Exposes a [`KeyPackageRef`] as its raw reference bytes.
impl Deref for KeyPackageRef {
    type Target = [u8];

    // `&self.0` is a `&HashReference` and is coerced to `&[u8]` at the return site.
    // NOTE(review): this relies on `HashReference`'s own `Deref` chain reaching `[u8]`.
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
80 
81 impl From<Vec<u8>> for KeyPackageRef {
from(v: Vec<u8>) -> Self82     fn from(v: Vec<u8>) -> Self {
83         Self(HashReference::from(v))
84     }
85 }
86 
/// Borrowed, signature-less view of a [`KeyPackage`].
///
/// Its MLS encoding is the to-be-signed content returned by
/// `KeyPackage::signable_content`. Field order therefore defines the signed bytes and
/// must not change.
#[derive(MlsSize, MlsEncode)]
struct KeyPackageData<'a> {
    pub version: ProtocolVersion,
    pub cipher_suite: CipherSuite,
    // NOTE(review): `KeyPackage` encodes this field with `HpkePublicKey`'s own codec, while
    // this view forces `byte_vec`. Presumably both produce the same bytes, so the TBS
    // equals the `KeyPackage` encoding minus the signature — confirm.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub hpke_init_key: &'a HpkePublicKey,
    pub leaf_node: &'a LeafNode,
    pub extensions: &'a ExtensionList,
}
96 
97 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
98 impl KeyPackage {
99     #[cfg(feature = "ffi")]
version(&self) -> ProtocolVersion100     pub fn version(&self) -> ProtocolVersion {
101         self.version
102     }
103 
104     #[cfg(feature = "ffi")]
cipher_suite(&self) -> CipherSuite105     pub fn cipher_suite(&self) -> CipherSuite {
106         self.cipher_suite
107     }
108 
signing_identity(&self) -> &SigningIdentity109     pub fn signing_identity(&self) -> &SigningIdentity {
110         &self.leaf_node.signing_identity
111     }
112 
113     #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
114     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
to_reference<CP: CipherSuiteProvider>( &self, cipher_suite_provider: &CP, ) -> Result<KeyPackageRef, MlsError>115     pub async fn to_reference<CP: CipherSuiteProvider>(
116         &self,
117         cipher_suite_provider: &CP,
118     ) -> Result<KeyPackageRef, MlsError> {
119         if cipher_suite_provider.cipher_suite() != self.cipher_suite {
120             return Err(MlsError::CipherSuiteMismatch);
121         }
122 
123         Ok(KeyPackageRef(
124             HashReference::compute(
125                 &self.mls_encode_to_vec()?,
126                 b"MLS 1.0 KeyPackage Reference",
127                 cipher_suite_provider,
128             )
129             .await?,
130         ))
131     }
132 
expiration(&self) -> Result<u64, MlsError>133     pub fn expiration(&self) -> Result<u64, MlsError> {
134         if let LeafNodeSource::KeyPackage(lifetime) = &self.leaf_node.leaf_node_source {
135             Ok(lifetime.not_after)
136         } else {
137             Err(MlsError::InvalidLeafNodeSource)
138         }
139     }
140 }
141 
142 impl<'a> Signable<'a> for KeyPackage {
143     const SIGN_LABEL: &'static str = "KeyPackageTBS";
144 
145     type SigningContext = ();
146 
signature(&self) -> &[u8]147     fn signature(&self) -> &[u8] {
148         &self.signature
149     }
150 
signable_content( &self, _context: &Self::SigningContext, ) -> Result<Vec<u8>, mls_rs_codec::Error>151     fn signable_content(
152         &self,
153         _context: &Self::SigningContext,
154     ) -> Result<Vec<u8>, mls_rs_codec::Error> {
155         KeyPackageData {
156             version: self.version,
157             cipher_suite: self.cipher_suite,
158             hpke_init_key: &self.hpke_init_key,
159             leaf_node: &self.leaf_node,
160             extensions: &self.extensions,
161         }
162         .mls_encode_to_vec()
163     }
164 
write_signature(&mut self, signature: Vec<u8>)165     fn write_signature(&mut self, signature: Vec<u8>) {
166         self.signature = signature
167     }
168 }
169 
#[cfg(test)]
pub(crate) mod test_utils {
    use super::*;
    use crate::{
        crypto::test_utils::test_cipher_suite_provider,
        group::framing::MlsMessagePayload,
        identity::basic::BasicIdentityProvider,
        identity::test_utils::get_test_signing_identity,
        tree_kem::{leaf_node::test_utils::get_test_capabilities, Lifetime},
        MlsMessage,
    };

    use mls_rs_core::crypto::SignatureSecretKey;

    /// Builds a signed test key package for `id` and discards the signing key.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> KeyPackage {
        let (key_package, _) =
            test_key_package_with_signer(protocol_version, cipher_suite, id).await;

        key_package
    }

    /// Builds a signed test key package for `id`, together with the secret key that
    /// signed it.
    ///
    /// The package uses a one-year lifetime, the test capabilities and two empty
    /// extension lists. Any generation failure panics, since this is test-only code.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package_with_signer(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> (KeyPackage, SignatureSecretKey) {
        let (signing_identity, secret_key) =
            get_test_signing_identity(cipher_suite, id.as_bytes()).await;

        let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);

        let generator = KeyPackageGenerator {
            protocol_version,
            cipher_suite_provider: &cipher_suite_provider,
            signing_identity: &signing_identity,
            signing_key: &secret_key,
            identity_provider: &BasicIdentityProvider,
        };

        let lifetime = Lifetime::years(1).unwrap();

        let generated = generator
            .generate(
                lifetime,
                get_test_capabilities(),
                ExtensionList::default(),
                ExtensionList::default(),
            )
            .await
            .unwrap();

        (generated.key_package, secret_key)
    }

    /// Wraps a freshly generated test key package for `id` in an [`MlsMessage`].
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package_message(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> MlsMessage {
        let key_package = test_key_package(protocol_version, cipher_suite, id).await;
        let payload = MlsMessagePayload::KeyPackage(key_package);

        MlsMessage::new(protocol_version, payload)
    }
}
240 
#[cfg(test)]
mod tests {
    use crate::{
        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
    };

    use super::{test_utils::test_key_package, *};
    use alloc::format;
    use assert_matches::assert_matches;

    /// One known-answer vector: an MLS-encoded key package (`input`) and its expected
    /// reference bytes (`output`), both hex-encoded in JSON.
    #[derive(serde::Deserialize, serde::Serialize)]
    struct TestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        input: Vec<u8>,
        #[serde(with = "hex::serde")]
        output: Vec<u8>,
    }

    impl TestCase {
        /// Generates one vector per (protocol version, cipher suite) pair. Each pair gets
        /// a distinct `alice{index}` identity.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        #[cfg_attr(coverage_nightly, coverage(off))]
        async fn generate() -> Vec<TestCase> {
            let combinations = ProtocolVersion::all()
                .flat_map(|p| CipherSuite::all().map(move |cs| (p, cs)));

            let mut test_cases = Vec::new();

            for (index, (protocol_version, cipher_suite)) in combinations.enumerate() {
                let identity = format!("alice{index}");
                let key_package = test_key_package(protocol_version, cipher_suite, &identity).await;

                let provider = test_cipher_suite_provider(cipher_suite);
                let reference = key_package.to_reference(&provider).await.unwrap();

                test_cases.push(TestCase {
                    cipher_suite: cipher_suite.into(),
                    input: key_package.mls_encode_to_vec().unwrap(),
                    output: reference.to_vec(),
                });
            }

            test_cases
        }
    }

    #[cfg(mls_build_async)]
    async fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(key_package_ref, TestCase::generate().await)
    }

    #[cfg(not(mls_build_async))]
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(key_package_ref, TestCase::generate())
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_key_package_ref() {
        for case in load_test_cases().await {
            // Skip suites the current crypto provider does not support.
            let provider = match try_test_cipher_suite_provider(case.cipher_suite) {
                Some(provider) => provider,
                None => continue,
            };

            let decoded = KeyPackage::mls_decode(&mut case.input.as_slice()).unwrap();
            let computed = decoded.to_reference(&provider).await.unwrap();

            assert_eq!(KeyPackageRef::from(case.output), computed);
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn key_package_ref_fails_invalid_cipher_suite() {
        let key_package = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "test").await;

        // Providers for every supported suite other than the key package's own.
        let mismatched_providers = CipherSuite::all()
            .filter(|cs| cs != &TEST_CIPHER_SUITE)
            .filter_map(|cs| try_test_cipher_suite_provider(*cs));

        for provider in mismatched_providers {
            let res = key_package.to_reference(&provider).await;
            assert_matches!(res, Err(MlsError::CipherSuiteMismatch));
        }
    }
}
333