1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec::Vec;
6 
7 #[cfg(any(test, feature = "external_client"))]
8 use alloc::vec;
9 
10 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
11 
12 #[cfg(any(test, feature = "external_client"))]
13 use mls_rs_core::psk::PreSharedKeyStorage;
14 
15 #[cfg(any(test, feature = "external_client"))]
16 use core::convert::Infallible;
17 use core::fmt::{self, Debug};
18 
19 #[cfg(feature = "psk")]
20 use crate::{client::MlsError, CipherSuiteProvider};
21 
22 #[cfg(feature = "psk")]
23 use mls_rs_core::error::IntoAnyError;
24 
25 #[cfg(feature = "psk")]
26 pub(crate) mod resolver;
27 pub(crate) mod secret;
28 
29 pub use mls_rs_core::psk::{ExternalPskId, PreSharedKey};
30 
/// A complete PSK reference: which key is meant (`key_id`) plus a per-use
/// `psk_nonce`.
///
/// The derived `PartialOrd`/`Ord` compare `key_id` first and then `psk_nonce`
/// (field declaration order). NOTE(review): the field order presumably also
/// fixes the layout produced by the `MlsEncode`/`MlsDecode` derives (it appears
/// to mirror RFC 9420 `PreSharedKeyID`) — do not reorder fields.
#[derive(Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct PreSharedKeyID {
    /// Identifies the key itself (external id or resumption reference).
    pub key_id: JustPreSharedKeyID,
    /// Nonce accompanying this particular use of the key.
    pub psk_nonce: PskNonce,
}
38 
39 impl PreSharedKeyID {
40     #[cfg(feature = "psk")]
new<P: CipherSuiteProvider>( key_id: JustPreSharedKeyID, cs: &P, ) -> Result<Self, MlsError>41     pub(crate) fn new<P: CipherSuiteProvider>(
42         key_id: JustPreSharedKeyID,
43         cs: &P,
44     ) -> Result<Self, MlsError> {
45         Ok(Self {
46             key_id,
47             psk_nonce: PskNonce::random(cs)
48                 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?,
49         })
50     }
51 }
52 
/// Identifies which pre-shared key is referenced, without the per-use nonce.
///
/// `#[repr(u8)]` with explicit discriminants `1` (external) and `2`
/// (resumption). NOTE(review): these values appear to match the `PSKType`
/// code points of RFC 9420 and presumably serve as the encoded tag for the
/// `MlsEncode`/`MlsDecode` derives — do not renumber without confirming.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialOrd, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub(crate) enum JustPreSharedKeyID {
    // An application-supplied PSK, named by an `ExternalPskId`.
    External(ExternalPskId) = 1u8,
    // A resumption PSK, named by usage, group id and epoch (see `ResumptionPsk`).
    Resumption(ResumptionPsk) = 2u8,
}
61 
/// Opaque group identifier carried inside a resumption PSK reference.
///
/// The bytes are encoded with `mls_rs_codec::byte_vec` and, when the `serde`
/// feature is enabled, serialized via `mls_rs_core::vec_serde`. `Debug` is not
/// derived; it is implemented by hand so the bytes are pretty-printed.
#[derive(Clone, Eq, Hash, Ord, PartialOrd, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct PskGroupId(
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    pub Vec<u8>,
);
70 
71 impl Debug for PskGroupId {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result72     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73         mls_rs_core::debug::pretty_bytes(&self.0)
74             .named("PskGroupId")
75             .fmt(f)
76     }
77 }
78 
/// Per-use nonce attached to a `PreSharedKeyID`.
///
/// The bytes are encoded with `mls_rs_codec::byte_vec` and, when the `serde`
/// feature is enabled, serialized via `mls_rs_core::vec_serde`. With the `psk`
/// feature, `PskNonce::random` fills it with `kdf_extract_size()` random bytes.
/// `Debug` is implemented by hand so the bytes are pretty-printed.
#[derive(Clone, Eq, Hash, PartialEq, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct PskNonce(
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    pub Vec<u8>,
);
87 
88 impl Debug for PskNonce {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result89     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90         mls_rs_core::debug::pretty_bytes(&self.0)
91             .named("PskNonce")
92             .fmt(f)
93     }
94 }
95 
#[cfg(feature = "psk")]
impl PskNonce {
    /// Draws a fresh nonce whose length equals the provider's
    /// `kdf_extract_size()`.
    ///
    /// # Errors
    ///
    /// Propagates the provider's own error if random byte generation fails.
    pub fn random<P: CipherSuiteProvider>(
        cipher_suite_provider: &P,
    ) -> Result<Self, <P as CipherSuiteProvider>::Error> {
        let nonce_len = cipher_suite_provider.kdf_extract_size();

        cipher_suite_provider.random_bytes_vec(nonce_len).map(Self)
    }
}
106 
/// Identifies a resumption PSK by its usage, a group id and an epoch.
///
/// NOTE(review): appears to mirror the resumption branch of RFC 9420
/// `PreSharedKeyID`; field order presumably fixes the encoded layout via the
/// `MlsEncode`/`MlsDecode` derives and also drives the derived `Ord` — do not
/// reorder.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialOrd, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct ResumptionPsk {
    /// Purpose the resumption PSK is being used for.
    pub usage: ResumptionPSKUsage,
    /// Group id the PSK refers to (presumably the group whose epoch secret is used — confirm).
    pub psk_group_id: PskGroupId,
    /// Epoch number within that group.
    pub psk_epoch: u64,
}
115 
/// Purpose of a resumption PSK.
///
/// `#[repr(u8)]` with explicit discriminants 1..=3. NOTE(review): these appear
/// to match the RFC 9420 `ResumptionPSKUsage` code points (application, reinit,
/// branch) and presumably are the encoded values — do not renumber.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Ord, PartialOrd, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub(crate) enum ResumptionPSKUsage {
    Application = 1u8,
    Reinit = 2u8,
    Branch = 3u8,
}
125 
/// Encode-only label pairing a borrowed PSK reference with an `index` and a
/// `count`.
///
/// Only `MlsSize`/`MlsEncode` are derived, so this is never decoded.
/// NOTE(review): appears to mirror RFC 9420 `PSKLabel` (index of this PSK among
/// `count` PSKs) and is presumably serialized as input to PSK secret derivation
/// in the `secret` submodule — confirm there. Field order fixes the encoding.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode)]
struct PSKLabel<'a> {
    id: &'a PreSharedKeyID,
    index: u16,
    count: u16,
}
132 
/// A `PreSharedKeyStorage` whose `get` reports every external PSK id as
/// present, returning an empty key, and can never fail.
///
/// Compiled only for tests or the `external_client` feature.
/// NOTE(review): presumably used where a PSK must "resolve" but the real
/// secret is not available (e.g. an external client) — confirm at call sites.
#[cfg(any(test, feature = "external_client"))]
#[derive(Clone, Copy, Debug)]
pub(crate) struct AlwaysFoundPskStorage;
136 
#[cfg(any(test, feature = "external_client"))]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
impl PreSharedKeyStorage for AlwaysFoundPskStorage {
    // Lookups cannot fail.
    type Error = Infallible;

    // Ignores the requested id and always answers with an empty key.
    async fn get(&self, _: &ExternalPskId) -> Result<Option<PreSharedKey>, Self::Error> {
        let empty_key: PreSharedKey = vec![].into();

        Ok(Some(empty_key))
    }
}
147 
#[cfg(feature = "psk")]
#[cfg(test)]
pub(crate) mod test_utils {
    use super::PskNonce;
    use crate::crypto::test_utils::test_cipher_suite_provider;
    use mls_rs_core::crypto::CipherSuite;

    #[cfg(not(mls_build_async))]
    use mls_rs_core::{crypto::CipherSuiteProvider, psk::ExternalPskId};

    /// Builds an `ExternalPskId` from `kdf_extract_size()` random bytes drawn
    /// from `cipher_suite_provider`.
    ///
    /// # Panics
    ///
    /// Panics if the provider fails to generate the random bytes.
    #[cfg_attr(coverage_nightly, coverage(off))]
    #[cfg(not(mls_build_async))]
    pub(crate) fn make_external_psk_id<P: CipherSuiteProvider>(
        cipher_suite_provider: &P,
    ) -> ExternalPskId {
        let id_len = cipher_suite_provider.kdf_extract_size();
        let id_bytes = cipher_suite_provider.random_bytes_vec(id_len).unwrap();

        ExternalPskId::new(id_bytes)
    }

    /// Generates a fresh random `PskNonce` using the test provider for
    /// `cipher_suite`.
    ///
    /// # Panics
    ///
    /// Panics if nonce generation fails.
    pub(crate) fn make_nonce(cipher_suite: CipherSuite) -> PskNonce {
        let provider = test_cipher_suite_provider(cipher_suite);

        PskNonce::random(&provider).unwrap()
    }
}
175 
#[cfg(feature = "psk")]
#[cfg(test)]
mod tests {
    use crate::crypto::test_utils::TestCryptoProvider;
    use alloc::collections::BTreeSet;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use super::test_utils::make_nonce;

    /// Nonces sampled per cipher suite (one reference plus 1000 more, as the
    /// original test drew).
    const NONCES_PER_SUITE: usize = 1001;

    /// Every nonce generated for a cipher suite must differ from *all* other
    /// nonces generated for that suite, not merely from the first one.
    #[test]
    fn random_generation_of_nonces_is_random() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            // `PskNonce` derives `Ord`, so a set detects any pairwise repeat.
            let mut seen = BTreeSet::new();

            for _ in 0..NONCES_PER_SUITE {
                // `insert` returns false when an equal nonce was already seen.
                assert!(
                    seen.insert(make_nonce(cipher_suite)),
                    "duplicate PSK nonce generated for cipher suite {:?}",
                    cipher_suite
                );
            }
        }
    }
}
201