1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use crate::cipher_suite::CipherSuite; 6 use crate::client::MlsError; 7 use crate::crypto::HpkePublicKey; 8 use crate::hash_reference::HashReference; 9 use crate::identity::SigningIdentity; 10 use crate::protocol_version::ProtocolVersion; 11 use crate::signer::Signable; 12 use crate::tree_kem::leaf_node::{LeafNode, LeafNodeSource}; 13 use crate::CipherSuiteProvider; 14 use alloc::vec::Vec; 15 use core::{ 16 fmt::{self, Debug}, 17 ops::Deref, 18 }; 19 use mls_rs_codec::MlsDecode; 20 use mls_rs_codec::MlsEncode; 21 use mls_rs_codec::MlsSize; 22 use mls_rs_core::extension::ExtensionList; 23 24 mod validator; 25 pub(crate) use validator::*; 26 27 pub(crate) mod generator; 28 pub(crate) use generator::*; 29 30 #[non_exhaustive] 31 #[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq)] 32 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 33 #[cfg_attr( 34 all(feature = "ffi", not(test)), 35 safer_ffi_gen::ffi_type(clone, opaque) 36 )] 37 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 38 pub struct KeyPackage { 39 pub version: ProtocolVersion, 40 pub cipher_suite: CipherSuite, 41 pub hpke_init_key: HpkePublicKey, 42 pub(crate) leaf_node: LeafNode, 43 pub extensions: ExtensionList, 44 #[mls_codec(with = "mls_rs_codec::byte_vec")] 45 #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))] 46 pub signature: Vec<u8>, 47 } 48 49 impl Debug for KeyPackage { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 51 f.debug_struct("KeyPackage") 52 .field("version", &self.version) 53 .field("cipher_suite", &self.cipher_suite) 54 .field("hpke_init_key", &self.hpke_init_key) 55 .field("leaf_node", &self.leaf_node) 56 .field("extensions", &self.extensions) 57 .field( 58 "signature", 59 &mls_rs_core::debug::pretty_bytes(&self.signature), 60 ) 61 .finish() 62 } 63 } 64 65 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, MlsSize, MlsEncode, MlsDecode)] 66 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 67 #[cfg_attr( 68 all(feature = "ffi", not(test)), 69 safer_ffi_gen::ffi_type(clone, opaque) 70 )] 71 pub struct KeyPackageRef(HashReference); 72 73 impl Deref for KeyPackageRef { 74 type Target = [u8]; 75 deref(&self) -> &Self::Target76 fn deref(&self) -> &Self::Target { 77 &self.0 78 } 79 } 80 81 impl From<Vec<u8>> for KeyPackageRef { from(v: Vec<u8>) -> Self82 fn from(v: Vec<u8>) -> Self { 83 Self(HashReference::from(v)) 84 } 85 } 86 87 #[derive(MlsSize, MlsEncode)] 88 struct KeyPackageData<'a> { 89 pub version: ProtocolVersion, 90 pub cipher_suite: CipherSuite, 91 #[mls_codec(with = "mls_rs_codec::byte_vec")] 92 pub hpke_init_key: &'a HpkePublicKey, 93 pub leaf_node: &'a LeafNode, 94 pub extensions: &'a ExtensionList, 95 } 96 97 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)] 98 impl KeyPackage { 99 #[cfg(feature = "ffi")] version(&self) -> ProtocolVersion100 pub fn version(&self) -> ProtocolVersion { 101 self.version 102 } 103 104 #[cfg(feature = "ffi")] cipher_suite(&self) -> CipherSuite105 pub fn cipher_suite(&self) -> CipherSuite { 106 self.cipher_suite 107 } 108 signing_identity(&self) -> &SigningIdentity109 pub fn signing_identity(&self) -> &SigningIdentity { 110 &self.leaf_node.signing_identity 111 } 112 113 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)] 114 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] to_reference<CP: CipherSuiteProvider>( &self, cipher_suite_provider: &CP, ) -> Result<KeyPackageRef, MlsError>115 pub async fn to_reference<CP: CipherSuiteProvider>( 116 &self, 117 cipher_suite_provider: &CP, 118 ) -> Result<KeyPackageRef, MlsError> { 119 if cipher_suite_provider.cipher_suite() != self.cipher_suite { 120 return Err(MlsError::CipherSuiteMismatch); 121 } 122 123 Ok(KeyPackageRef( 124 HashReference::compute( 125 &self.mls_encode_to_vec()?, 126 b"MLS 1.0 KeyPackage Reference", 127 cipher_suite_provider, 128 ) 129 .await?, 130 )) 131 } 132 expiration(&self) -> Result<u64, MlsError>133 pub fn expiration(&self) -> Result<u64, MlsError> { 134 if let LeafNodeSource::KeyPackage(lifetime) = &self.leaf_node.leaf_node_source { 135 Ok(lifetime.not_after) 136 } else { 137 Err(MlsError::InvalidLeafNodeSource) 138 } 139 } 140 } 141 142 impl<'a> Signable<'a> for KeyPackage { 143 const SIGN_LABEL: &'static str = "KeyPackageTBS"; 144 145 type SigningContext = (); 146 signature(&self) -> &[u8]147 fn signature(&self) -> &[u8] { 148 &self.signature 149 } 150 signable_content( &self, _context: &Self::SigningContext, ) -> Result<Vec<u8>, mls_rs_codec::Error>151 fn signable_content( 152 &self, 153 _context: &Self::SigningContext, 154 ) -> Result<Vec<u8>, mls_rs_codec::Error> { 155 KeyPackageData { 156 version: self.version, 157 cipher_suite: self.cipher_suite, 158 hpke_init_key: &self.hpke_init_key, 159 leaf_node: &self.leaf_node, 160 extensions: &self.extensions, 161 } 162 .mls_encode_to_vec() 163 } 164 write_signature(&mut self, signature: Vec<u8>)165 fn write_signature(&mut self, signature: Vec<u8>) { 166 self.signature = signature 167 } 168 } 169 170 #[cfg(test)] 171 pub(crate) mod test_utils { 172 use super::*; 173 use crate::{ 174 crypto::test_utils::test_cipher_suite_provider, 175 group::framing::MlsMessagePayload, 176 identity::basic::BasicIdentityProvider, 177 identity::test_utils::get_test_signing_identity, 178 tree_kem::{leaf_node::test_utils::get_test_capabilities, Lifetime}, 179 MlsMessage, 180 }; 181 182 use mls_rs_core::crypto::SignatureSecretKey; 183 184 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] test_key_package( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, id: &str, ) -> KeyPackage185 pub(crate) async fn test_key_package( 186 protocol_version: ProtocolVersion, 187 cipher_suite: CipherSuite, 188 id: &str, 189 ) -> KeyPackage { 190 test_key_package_with_signer(protocol_version, cipher_suite, id) 191 .await 192 .0 193 } 194 195 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] test_key_package_with_signer( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, id: &str, ) -> (KeyPackage, SignatureSecretKey)196 pub(crate) async fn test_key_package_with_signer( 197 protocol_version: ProtocolVersion, 198 cipher_suite: CipherSuite, 199 id: &str, 200 ) -> (KeyPackage, SignatureSecretKey) { 201 let (signing_identity, secret_key) = 202 get_test_signing_identity(cipher_suite, id.as_bytes()).await; 203 204 let generator = KeyPackageGenerator { 205 protocol_version, 206 cipher_suite_provider: &test_cipher_suite_provider(cipher_suite), 207 signing_identity: &signing_identity, 208 signing_key: &secret_key, 209 identity_provider: &BasicIdentityProvider, 210 }; 211 212 let key_package = generator 213 .generate( 214 Lifetime::years(1).unwrap(), 215 get_test_capabilities(), 216 ExtensionList::default(), 217 ExtensionList::default(), 218 ) 219 .await 220 .unwrap() 221 .key_package; 222 223 (key_package, secret_key) 224 } 225 226 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] test_key_package_message( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, id: &str, ) -> MlsMessage227 pub(crate) async fn test_key_package_message( 228 protocol_version: ProtocolVersion, 229 cipher_suite: CipherSuite, 230 id: &str, 231 ) -> MlsMessage { 232 MlsMessage::new( 233 protocol_version, 234 MlsMessagePayload::KeyPackage( 235 test_key_package(protocol_version, cipher_suite, id).await, 236 ), 237 ) 238 } 239 } 240 241 #[cfg(test)] 242 mod tests { 243 use crate::{ 244 client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION}, 245 crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider}, 246 }; 247 248 use super::{test_utils::test_key_package, *}; 249 use alloc::format; 250 use assert_matches::assert_matches; 251 252 #[derive(serde::Deserialize, serde::Serialize)] 253 struct TestCase { 254 cipher_suite: u16, 255 #[serde(with = "hex::serde")] 256 input: Vec<u8>, 257 #[serde(with = "hex::serde")] 258 output: Vec<u8>, 259 } 260 261 impl TestCase { 262 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] 263 #[cfg_attr(coverage_nightly, coverage(off))] generate() -> Vec<TestCase>264 async fn generate() -> Vec<TestCase> { 265 let mut test_cases = Vec::new(); 266 267 for (i, (protocol_version, cipher_suite)) in ProtocolVersion::all() 268 .flat_map(|p| CipherSuite::all().map(move |cs| (p, cs))) 269 .enumerate() 270 { 271 let pkg = 272 test_key_package(protocol_version, cipher_suite, &format!("alice{i}")).await; 273 274 let pkg_ref = pkg 275 .to_reference(&test_cipher_suite_provider(cipher_suite)) 276 .await 277 .unwrap(); 278 279 let case = TestCase { 280 cipher_suite: cipher_suite.into(), 281 input: pkg.mls_encode_to_vec().unwrap(), 282 output: pkg_ref.to_vec(), 283 }; 284 285 test_cases.push(case); 286 } 287 288 test_cases 289 } 290 } 291 292 #[cfg(mls_build_async)] load_test_cases() -> Vec<TestCase>293 async fn load_test_cases() -> Vec<TestCase> { 294 load_test_case_json!(key_package_ref, TestCase::generate().await) 295 } 296 297 #[cfg(not(mls_build_async))] load_test_cases() -> Vec<TestCase>298 fn load_test_cases() -> Vec<TestCase> { 299 load_test_case_json!(key_package_ref, TestCase::generate()) 300 } 301 302 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_key_package_ref()303 async fn test_key_package_ref() { 304 let cases = load_test_cases().await; 305 306 for one_case in cases { 307 let Some(provider) = try_test_cipher_suite_provider(one_case.cipher_suite) else { 308 continue; 309 }; 310 311 let key_package = KeyPackage::mls_decode(&mut one_case.input.as_slice()).unwrap(); 312 313 let key_package_ref = key_package.to_reference(&provider).await.unwrap(); 314 315 let expected_out = KeyPackageRef::from(one_case.output); 316 assert_eq!(expected_out, key_package_ref); 317 } 318 } 319 320 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] key_package_ref_fails_invalid_cipher_suite()321 async fn key_package_ref_fails_invalid_cipher_suite() { 322 let key_package = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "test").await; 323 324 for another_cipher_suite in CipherSuite::all().filter(|cs| cs != &TEST_CIPHER_SUITE) { 325 if let Some(cs) = try_test_cipher_suite_provider(*another_cipher_suite) { 326 let res = key_package.to_reference(&cs).await; 327 328 assert_matches!(res, Err(MlsError::CipherSuiteMismatch)); 329 } 330 } 331 } 332 } 333