1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use alloc::vec::Vec; 6 use core::fmt::{self, Debug}; 7 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; 8 use mls_rs_core::extension::{ExtensionType, MlsCodecExtension}; 9 10 use mls_rs_core::{group::ProposalType, identity::CredentialType}; 11 12 #[cfg(feature = "by_ref_proposal")] 13 use mls_rs_core::{ 14 extension::ExtensionList, 15 identity::{IdentityProvider, SigningIdentity}, 16 time::MlsTime, 17 }; 18 19 use crate::group::ExportedTree; 20 21 use mls_rs_core::crypto::HpkePublicKey; 22 23 /// Application specific identifier. 24 /// 25 /// A custom application level identifier that can be optionally stored 26 /// within the `leaf_node_extensions` of a group [Member](crate::group::Member). 27 #[cfg_attr( 28 all(feature = "ffi", not(test)), 29 safer_ffi_gen::ffi_type(clone, opaque) 30 )] 31 #[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)] 32 pub struct ApplicationIdExt { 33 /// Application level identifier presented by this extension. 34 #[mls_codec(with = "mls_rs_codec::byte_vec")] 35 pub identifier: Vec<u8>, 36 } 37 38 impl Debug for ApplicationIdExt { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 40 f.debug_struct("ApplicationIdExt") 41 .field( 42 "identifier", 43 &mls_rs_core::debug::pretty_bytes(&self.identifier), 44 ) 45 .finish() 46 } 47 } 48 49 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)] 50 impl ApplicationIdExt { 51 /// Create a new application level identifier extension. new(identifier: Vec<u8>) -> Self52 pub fn new(identifier: Vec<u8>) -> Self { 53 ApplicationIdExt { identifier } 54 } 55 56 /// Get the application level identifier presented by this extension. 57 #[cfg(feature = "ffi")] identifier(&self) -> &[u8]58 pub fn identifier(&self) -> &[u8] { 59 &self.identifier 60 } 61 } 62 63 impl MlsCodecExtension for ApplicationIdExt { extension_type() -> ExtensionType64 fn extension_type() -> ExtensionType { 65 ExtensionType::APPLICATION_ID 66 } 67 } 68 69 /// Representation of an MLS ratchet tree. 70 /// 71 /// Used to provide new members 72 /// a copy of the current group state in-band. 73 #[cfg_attr( 74 all(feature = "ffi", not(test)), 75 safer_ffi_gen::ffi_type(clone, opaque) 76 )] 77 #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)] 78 pub struct RatchetTreeExt { 79 pub tree_data: ExportedTree<'static>, 80 } 81 82 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)] 83 impl RatchetTreeExt { 84 /// Required custom extension types. 85 #[cfg(feature = "ffi")] tree_data(&self) -> &ExportedTree<'static>86 pub fn tree_data(&self) -> &ExportedTree<'static> { 87 &self.tree_data 88 } 89 } 90 91 impl MlsCodecExtension for RatchetTreeExt { extension_type() -> ExtensionType92 fn extension_type() -> ExtensionType { 93 ExtensionType::RATCHET_TREE 94 } 95 } 96 97 /// Require members to have certain capabilities. 98 /// 99 /// Used within a 100 /// [Group Context Extensions Proposal](crate::group::proposal::Proposal) 101 /// in order to require that all current and future members of a group MUST 102 /// support specific extensions, proposals, or credentials. 103 /// 104 /// # Warning 105 /// 106 /// Extension, proposal, and credential types defined by the MLS RFC and 107 /// provided are considered required by default and should NOT be used 108 /// within this extension. 109 #[cfg_attr( 110 all(feature = "ffi", not(test)), 111 safer_ffi_gen::ffi_type(clone, opaque) 112 )] 113 #[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode, Default)] 114 pub struct RequiredCapabilitiesExt { 115 pub extensions: Vec<ExtensionType>, 116 pub proposals: Vec<ProposalType>, 117 pub credentials: Vec<CredentialType>, 118 } 119 120 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)] 121 impl RequiredCapabilitiesExt { 122 /// Create a required capabilities extension. new( extensions: Vec<ExtensionType>, proposals: Vec<ProposalType>, credentials: Vec<CredentialType>, ) -> Self123 pub fn new( 124 extensions: Vec<ExtensionType>, 125 proposals: Vec<ProposalType>, 126 credentials: Vec<CredentialType>, 127 ) -> Self { 128 Self { 129 extensions, 130 proposals, 131 credentials, 132 } 133 } 134 135 /// Required custom extension types. 136 #[cfg(feature = "ffi")] extensions(&self) -> &[ExtensionType]137 pub fn extensions(&self) -> &[ExtensionType] { 138 &self.extensions 139 } 140 141 /// Required custom proposal types. 142 #[cfg(feature = "ffi")] proposals(&self) -> &[ProposalType]143 pub fn proposals(&self) -> &[ProposalType] { 144 &self.proposals 145 } 146 147 /// Required custom credential types. 148 #[cfg(feature = "ffi")] credentials(&self) -> &[CredentialType]149 pub fn credentials(&self) -> &[CredentialType] { 150 &self.credentials 151 } 152 } 153 154 impl MlsCodecExtension for RequiredCapabilitiesExt { extension_type() -> ExtensionType155 fn extension_type() -> ExtensionType { 156 ExtensionType::REQUIRED_CAPABILITIES 157 } 158 } 159 160 /// External public key used for [External Commits](crate::Client::commit_external). 161 /// 162 /// This proposal type is optionally provided as part of a 163 /// [Group Info](crate::group::Group::group_info_message). 164 #[cfg_attr( 165 all(feature = "ffi", not(test)), 166 safer_ffi_gen::ffi_type(clone, opaque) 167 )] 168 #[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)] 169 pub struct ExternalPubExt { 170 /// Public key to be used for an external commit. 171 #[mls_codec(with = "mls_rs_codec::byte_vec")] 172 pub external_pub: HpkePublicKey, 173 } 174 175 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)] 176 impl ExternalPubExt { 177 /// Get the public key to be used for an external commit. 178 #[cfg(feature = "ffi")] external_pub(&self) -> &HpkePublicKey179 pub fn external_pub(&self) -> &HpkePublicKey { 180 &self.external_pub 181 } 182 } 183 184 impl MlsCodecExtension for ExternalPubExt { extension_type() -> ExtensionType185 fn extension_type() -> ExtensionType { 186 ExtensionType::EXTERNAL_PUB 187 } 188 } 189 190 /// Enable proposals by an [ExternalClient](crate::external_client::ExternalClient). 191 #[cfg(feature = "by_ref_proposal")] 192 #[cfg_attr( 193 all(feature = "ffi", not(test)), 194 safer_ffi_gen::ffi_type(clone, opaque) 195 )] 196 #[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)] 197 #[non_exhaustive] 198 pub struct ExternalSendersExt { 199 pub allowed_senders: Vec<SigningIdentity>, 200 } 201 202 #[cfg(feature = "by_ref_proposal")] 203 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)] 204 impl ExternalSendersExt { new(allowed_senders: Vec<SigningIdentity>) -> Self205 pub fn new(allowed_senders: Vec<SigningIdentity>) -> Self { 206 Self { allowed_senders } 207 } 208 209 #[cfg(feature = "ffi")] allowed_senders(&self) -> &[SigningIdentity]210 pub fn allowed_senders(&self) -> &[SigningIdentity] { 211 &self.allowed_senders 212 } 213 214 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] verify_all<I: IdentityProvider>( &self, provider: &I, timestamp: Option<MlsTime>, group_context_extensions: &ExtensionList, ) -> Result<(), I::Error>215 pub(crate) async fn verify_all<I: IdentityProvider>( 216 &self, 217 provider: &I, 218 timestamp: Option<MlsTime>, 219 group_context_extensions: &ExtensionList, 220 ) -> Result<(), I::Error> { 221 for id in self.allowed_senders.iter() { 222 provider 223 .validate_external_sender(id, timestamp, Some(group_context_extensions)) 224 .await?; 225 } 226 227 Ok(()) 228 } 229 } 230 231 #[cfg(feature = "by_ref_proposal")] 232 impl MlsCodecExtension for ExternalSendersExt { extension_type() -> ExtensionType233 fn extension_type() -> ExtensionType { 234 ExtensionType::EXTERNAL_SENDERS 235 } 236 } 237 238 #[cfg(test)] 239 mod tests { 240 use super::*; 241 242 use crate::tree_kem::node::NodeVec; 243 #[cfg(feature = "by_ref_proposal")] 244 use crate::{ 245 client::test_utils::TEST_CIPHER_SUITE, identity::test_utils::get_test_signing_identity, 246 }; 247 248 use mls_rs_core::extension::MlsExtension; 249 250 use mls_rs_core::identity::BasicCredential; 251 252 use alloc::vec; 253 254 #[cfg(target_arch = "wasm32")] 255 use wasm_bindgen_test::wasm_bindgen_test as test; 256 257 #[test] test_application_id_extension()258 fn test_application_id_extension() { 259 let test_id = vec![0u8; 32]; 260 let test_extension = ApplicationIdExt { 261 identifier: test_id.clone(), 262 }; 263 264 let as_extension = test_extension.into_extension().unwrap(); 265 266 assert_eq!(as_extension.extension_type, ExtensionType::APPLICATION_ID); 267 268 let restored = ApplicationIdExt::from_extension(&as_extension).unwrap(); 269 assert_eq!(restored.identifier, test_id); 270 } 271 272 #[test] test_ratchet_tree()273 fn test_ratchet_tree() { 274 let ext = RatchetTreeExt { 275 tree_data: ExportedTree::new(NodeVec::from(vec![None, None])), 276 }; 277 278 let as_extension = ext.clone().into_extension().unwrap(); 279 assert_eq!(as_extension.extension_type, ExtensionType::RATCHET_TREE); 280 281 let restored = RatchetTreeExt::from_extension(&as_extension).unwrap(); 282 assert_eq!(ext, restored) 283 } 284 285 #[test] test_required_capabilities()286 fn test_required_capabilities() { 287 let ext = RequiredCapabilitiesExt { 288 extensions: vec![0.into(), 1.into()], 289 proposals: vec![42.into(), 43.into()], 290 credentials: vec![BasicCredential::credential_type()], 291 }; 292 293 let as_extension = ext.clone().into_extension().unwrap(); 294 295 assert_eq!( 296 as_extension.extension_type, 297 ExtensionType::REQUIRED_CAPABILITIES 298 ); 299 300 let restored = RequiredCapabilitiesExt::from_extension(&as_extension).unwrap(); 301 assert_eq!(ext, restored) 302 } 303 304 #[cfg(feature = "by_ref_proposal")] 305 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_external_senders()306 async fn test_external_senders() { 307 let identity = get_test_signing_identity(TEST_CIPHER_SUITE, &[1]).await.0; 308 let ext = ExternalSendersExt::new(vec![identity]); 309 310 let as_extension = ext.clone().into_extension().unwrap(); 311 312 assert_eq!(as_extension.extension_type, ExtensionType::EXTERNAL_SENDERS); 313 314 let restored = ExternalSendersExt::from_extension(&as_extension).unwrap(); 315 assert_eq!(ext, restored) 316 } 317 318 #[test] test_external_pub()319 fn test_external_pub() { 320 let ext = ExternalPubExt { 321 external_pub: vec![0, 1, 2, 3].into(), 322 }; 323 324 let as_extension = ext.clone().into_extension().unwrap(); 325 assert_eq!(as_extension.extension_type, ExtensionType::EXTERNAL_PUB); 326 327 let restored = ExternalPubExt::from_extension(&as_extension).unwrap(); 328 assert_eq!(ext, restored) 329 } 330 } 331