1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use alloc::vec::Vec; 6 7 #[cfg(any(test, feature = "external_client"))] 8 use alloc::vec; 9 10 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; 11 12 #[cfg(any(test, feature = "external_client"))] 13 use mls_rs_core::psk::PreSharedKeyStorage; 14 15 #[cfg(any(test, feature = "external_client"))] 16 use core::convert::Infallible; 17 use core::fmt::{self, Debug}; 18 19 #[cfg(feature = "psk")] 20 use crate::{client::MlsError, CipherSuiteProvider}; 21 22 #[cfg(feature = "psk")] 23 use mls_rs_core::error::IntoAnyError; 24 25 #[cfg(feature = "psk")] 26 pub(crate) mod resolver; 27 pub(crate) mod secret; 28 29 pub use mls_rs_core::psk::{ExternalPskId, PreSharedKey}; 30 31 #[derive(Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)] 32 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 33 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 34 pub(crate) struct PreSharedKeyID { 35 pub key_id: JustPreSharedKeyID, 36 pub psk_nonce: PskNonce, 37 } 38 39 impl PreSharedKeyID { 40 #[cfg(feature = "psk")] new<P: CipherSuiteProvider>( key_id: JustPreSharedKeyID, cs: &P, ) -> Result<Self, MlsError>41 pub(crate) fn new<P: CipherSuiteProvider>( 42 key_id: JustPreSharedKeyID, 43 cs: &P, 44 ) -> Result<Self, MlsError> { 45 Ok(Self { 46 key_id, 47 psk_nonce: PskNonce::random(cs) 48 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?, 49 }) 50 } 51 } 52 53 #[derive(Clone, Debug, Eq, Hash, Ord, PartialOrd, PartialEq, MlsSize, MlsEncode, MlsDecode)] 54 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 55 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 56 #[repr(u8)] 57 pub(crate) enum JustPreSharedKeyID { 58 External(ExternalPskId) = 1u8, 59 Resumption(ResumptionPsk) = 2u8, 60 } 61 62 #[derive(Clone, Eq, Hash, Ord, PartialOrd, PartialEq, MlsSize, MlsEncode, MlsDecode)] 63 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 64 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 65 pub(crate) struct PskGroupId( 66 #[mls_codec(with = "mls_rs_codec::byte_vec")] 67 #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))] 68 pub Vec<u8>, 69 ); 70 71 impl Debug for PskGroupId { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result72 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 73 mls_rs_core::debug::pretty_bytes(&self.0) 74 .named("PskGroupId") 75 .fmt(f) 76 } 77 } 78 79 #[derive(Clone, Eq, Hash, PartialEq, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)] 80 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 81 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 82 pub(crate) struct PskNonce( 83 #[mls_codec(with = "mls_rs_codec::byte_vec")] 84 #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))] 85 pub Vec<u8>, 86 ); 87 88 impl Debug for PskNonce { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result89 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 90 mls_rs_core::debug::pretty_bytes(&self.0) 91 .named("PskNonce") 92 .fmt(f) 93 } 94 } 95 96 #[cfg(feature = "psk")] 97 impl PskNonce { random<P: CipherSuiteProvider>( cipher_suite_provider: &P, ) -> Result<Self, <P as CipherSuiteProvider>::Error>98 pub fn random<P: CipherSuiteProvider>( 99 cipher_suite_provider: &P, 100 ) -> Result<Self, <P as CipherSuiteProvider>::Error> { 101 Ok(Self(cipher_suite_provider.random_bytes_vec( 102 cipher_suite_provider.kdf_extract_size(), 103 )?)) 104 } 105 } 106 107 #[derive(Clone, Debug, Eq, Hash, Ord, PartialOrd, PartialEq, MlsSize, MlsEncode, MlsDecode)] 108 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 109 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 110 pub(crate) struct ResumptionPsk { 111 pub usage: ResumptionPSKUsage, 112 pub psk_group_id: PskGroupId, 113 pub psk_epoch: u64, 114 } 115 116 #[derive(Clone, Debug, Eq, Hash, PartialEq, Ord, PartialOrd, MlsSize, MlsEncode, MlsDecode)] 117 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 118 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 119 #[repr(u8)] 120 pub(crate) enum ResumptionPSKUsage { 121 Application = 1u8, 122 Reinit = 2u8, 123 Branch = 3u8, 124 } 125 126 #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode)] 127 struct PSKLabel<'a> { 128 id: &'a PreSharedKeyID, 129 index: u16, 130 count: u16, 131 } 132 133 #[cfg(any(test, feature = "external_client"))] 134 #[derive(Clone, Copy, Debug)] 135 pub(crate) struct AlwaysFoundPskStorage; 136 137 #[cfg(any(test, feature = "external_client"))] 138 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] 139 #[cfg_attr(mls_build_async, maybe_async::must_be_async)] 140 impl PreSharedKeyStorage for AlwaysFoundPskStorage { 141 type Error = Infallible; 142 get(&self, _: &ExternalPskId) -> Result<Option<PreSharedKey>, Self::Error>143 async fn get(&self, _: &ExternalPskId) -> Result<Option<PreSharedKey>, Self::Error> { 144 Ok(Some(vec![].into())) 145 } 146 } 147 148 #[cfg(feature = "psk")] 149 #[cfg(test)] 150 pub(crate) mod test_utils { 151 use crate::crypto::test_utils::test_cipher_suite_provider; 152 153 use super::PskNonce; 154 use mls_rs_core::crypto::CipherSuite; 155 156 #[cfg(not(mls_build_async))] 157 use mls_rs_core::{crypto::CipherSuiteProvider, psk::ExternalPskId}; 158 159 #[cfg_attr(coverage_nightly, coverage(off))] 160 #[cfg(not(mls_build_async))] make_external_psk_id<P: CipherSuiteProvider>( cipher_suite_provider: &P, ) -> ExternalPskId161 pub(crate) fn make_external_psk_id<P: CipherSuiteProvider>( 162 cipher_suite_provider: &P, 163 ) -> ExternalPskId { 164 ExternalPskId::new( 165 cipher_suite_provider 166 .random_bytes_vec(cipher_suite_provider.kdf_extract_size()) 167 .unwrap(), 168 ) 169 } 170 make_nonce(cipher_suite: CipherSuite) -> PskNonce171 pub(crate) fn make_nonce(cipher_suite: CipherSuite) -> PskNonce { 172 PskNonce::random(&test_cipher_suite_provider(cipher_suite)).unwrap() 173 } 174 } 175 176 #[cfg(feature = "psk")] 177 #[cfg(test)] 178 mod tests { 179 use crate::crypto::test_utils::TestCryptoProvider; 180 use core::iter; 181 182 #[cfg(target_arch = "wasm32")] 183 use wasm_bindgen_test::wasm_bindgen_test as test; 184 185 use super::test_utils::make_nonce; 186 187 #[test] random_generation_of_nonces_is_random()188 fn random_generation_of_nonces_is_random() { 189 let good = TestCryptoProvider::all_supported_cipher_suites() 190 .into_iter() 191 .all(|cipher_suite| { 192 let nonce = make_nonce(cipher_suite); 193 iter::repeat_with(|| make_nonce(cipher_suite)) 194 .take(1000) 195 .all(|other| other != nonce) 196 }); 197 198 assert!(good); 199 } 200 } 201