1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec::Vec;
6 use core::fmt::{self, Debug};
7 use mls_rs_codec::{MlsEncode, MlsSize};
8 use mls_rs_core::error::IntoAnyError;
9 
10 use crate::client::MlsError;
11 use crate::crypto::{CipherSuiteProvider, SignaturePublicKey, SignatureSecretKey};
12 
/// Structure whose encoding is what `Signable::sign` signs and
/// `Signable::verify` checks.
///
/// Both fields are encoded as variable-length byte vectors via
/// `mls_rs_codec::byte_vec`. NOTE(review): the derived `MlsEncode` presumably
/// encodes fields in declaration order (confirm in `mls_rs_codec`), so
/// reordering the fields would change the signed bytes.
#[derive(Clone, MlsSize, MlsEncode)]
struct SignContent {
    // Full label: `"MLS 1.0 "` followed by the caller's label (built in `SignContent::new`).
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    label: Vec<u8>,
    // Bytes produced by `Signable::signable_content` for the object being signed.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    content: Vec<u8>,
}
20 
21 impl Debug for SignContent {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result22     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23         f.debug_struct("SignContent")
24             .field("label", &mls_rs_core::debug::pretty_bytes(&self.label))
25             .field("content", &mls_rs_core::debug::pretty_bytes(&self.content))
26             .finish()
27     }
28 }
29 
30 impl SignContent {
new(label: &str, content: Vec<u8>) -> Self31     pub fn new(label: &str, content: Vec<u8>) -> Self {
32         Self {
33             label: [b"MLS 1.0 ", label.as_bytes()].concat(),
34             content,
35         }
36     }
37 }
38 
39 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
40 #[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
41 #[cfg_attr(
42     all(not(target_arch = "wasm32"), mls_build_async),
43     maybe_async::must_be_async
44 )]
45 pub(crate) trait Signable<'a> {
46     const SIGN_LABEL: &'static str;
47 
48     type SigningContext: Send + Sync;
49 
signature(&self) -> &[u8]50     fn signature(&self) -> &[u8];
51 
signable_content( &self, context: &Self::SigningContext, ) -> Result<Vec<u8>, mls_rs_codec::Error>52     fn signable_content(
53         &self,
54         context: &Self::SigningContext,
55     ) -> Result<Vec<u8>, mls_rs_codec::Error>;
56 
write_signature(&mut self, signature: Vec<u8>)57     fn write_signature(&mut self, signature: Vec<u8>);
58 
sign<P: CipherSuiteProvider>( &mut self, signature_provider: &P, signer: &SignatureSecretKey, context: &Self::SigningContext, ) -> Result<(), MlsError>59     async fn sign<P: CipherSuiteProvider>(
60         &mut self,
61         signature_provider: &P,
62         signer: &SignatureSecretKey,
63         context: &Self::SigningContext,
64     ) -> Result<(), MlsError> {
65         let sign_content = SignContent::new(Self::SIGN_LABEL, self.signable_content(context)?);
66 
67         let signature = signature_provider
68             .sign(signer, &sign_content.mls_encode_to_vec()?)
69             .await
70             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
71 
72         self.write_signature(signature);
73 
74         Ok(())
75     }
76 
verify<P: CipherSuiteProvider>( &self, signature_provider: &P, public_key: &SignaturePublicKey, context: &Self::SigningContext, ) -> Result<(), MlsError>77     async fn verify<P: CipherSuiteProvider>(
78         &self,
79         signature_provider: &P,
80         public_key: &SignaturePublicKey,
81         context: &Self::SigningContext,
82     ) -> Result<(), MlsError> {
83         let sign_content = SignContent::new(Self::SIGN_LABEL, self.signable_content(context)?);
84 
85         signature_provider
86             .verify(
87                 public_key,
88                 self.signature(),
89                 &sign_content.mls_encode_to_vec()?,
90             )
91             .await
92             .map_err(|_| MlsError::InvalidSignature)
93     }
94 }
95 
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::vec;
    use alloc::{string::String, vec::Vec};
    use mls_rs_core::crypto::CipherSuiteProvider;

    use crate::crypto::test_utils::try_test_cipher_suite_provider;

    use super::Signable;

    /// The `sign_with_label` entry of one `basic_crypto` interop test vector.
    /// Byte fields are hex-encoded in the JSON.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct SignatureInteropTestCase {
        // Signing secret key (JSON key `priv`); not needed by `verify`.
        #[serde(with = "hex::serde", rename = "priv")]
        secret: Vec<u8>,
        // Signature public key (JSON key `pub`).
        #[serde(with = "hex::serde", rename = "pub")]
        public: Vec<u8>,
        // Content covered by `signature`.
        #[serde(with = "hex::serde")]
        content: Vec<u8>,
        // Label the vector was signed with; must match `TestSignable::SIGN_LABEL`,
        // because that is the label `verify` actually applies.
        label: String,
        #[serde(with = "hex::serde")]
        signature: Vec<u8>,
    }

    /// One entry of the `basic_crypto` interop vector file.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct InteropTestCase {
        cipher_suite: u16,
        sign_with_label: SignatureInteropTestCase,
    }

    /// Verifies every `sign_with_label` vector whose cipher suite is available
    /// in this build; vectors for unavailable suites are skipped.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_basic_crypto_test_vectors() {
        let test_cases: Vec<InteropTestCase> =
            load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());

        for test_case in test_cases {
            if let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) {
                test_case.sign_with_label.verify(&cs).await;
            }
        }
    }

    /// Minimal `Signable` for tests: the covered bytes are `context || content`.
    pub struct TestSignable {
        pub content: Vec<u8>,
        pub signature: Vec<u8>,
    }

    impl<'a> Signable<'a> for TestSignable {
        const SIGN_LABEL: &'static str = "SignWithLabel";

        type SigningContext = Vec<u8>;

        fn signature(&self) -> &[u8] {
            &self.signature
        }

        fn signable_content(
            &self,
            context: &Self::SigningContext,
        ) -> Result<Vec<u8>, mls_rs_codec::Error> {
            Ok([context.as_slice(), self.content.as_slice()].concat())
        }

        fn write_signature(&mut self, signature: Vec<u8>) {
            self.signature = signature
        }
    }

    impl SignatureInteropTestCase {
        /// Verifies the vector's `signature` over `content` (with an empty
        /// signing context) using `cs`.
        ///
        /// # Panics
        /// Panics if the vector's label is not `TestSignable::SIGN_LABEL`
        /// (the only label this harness can apply), or if verification fails.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) {
            // `TestSignable` always signs with its fixed label; a vector using a
            // different label would otherwise fail as an opaque InvalidSignature.
            assert_eq!(
                self.label.as_str(),
                TestSignable::SIGN_LABEL,
                "interop vector label differs from TestSignable::SIGN_LABEL"
            );

            let public = self.public.clone().into();

            let signable = TestSignable {
                content: self.content.clone(),
                signature: self.signature.clone(),
            };

            signable.verify(cs, &public, &vec![]).await.unwrap();
        }
    }
}
177 
#[cfg(test)]
mod tests {
    use super::{test_utils::TestSignable, *};
    use crate::{
        client::test_utils::TEST_CIPHER_SUITE,
        crypto::test_utils::{
            test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider,
        },
        group::test_utils::random_bytes,
    };
    use alloc::vec;
    use assert_matches::assert_matches;

    /// A stored signature vector: `signature` covers `context || content` and
    /// was produced with `signer`; `public` is the matching verification key.
    /// All byte fields are hex-encoded in the JSON.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct TestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        content: Vec<u8>,
        #[serde(with = "hex::serde")]
        context: Vec<u8>,
        #[serde(with = "hex::serde")]
        signature: Vec<u8>,
        #[serde(with = "hex::serde")]
        signer: Vec<u8>,
        #[serde(with = "hex::serde")]
        public: Vec<u8>,
    }

    /// Builds one freshly signed vector per supported cipher suite, using a new
    /// key pair and random 32-byte content/context for each.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    #[cfg_attr(coverage_nightly, coverage(off))]
    async fn generate_test_cases() -> Vec<TestCase> {
        let mut generated = Vec::new();

        for suite in TestCryptoProvider::all_supported_cipher_suites() {
            let cs = test_cipher_suite_provider(suite);
            let (secret_key, public_key) = cs.signature_key_generate().await.unwrap();

            let mut signable = TestSignable {
                content: random_bytes(32),
                signature: Vec::new(),
            };
            let context = random_bytes(32);

            signable.sign(&cs, &secret_key, &context).await.unwrap();

            // Move the signed data out instead of cloning it beforehand.
            let TestSignable { content, signature } = signable;

            generated.push(TestCase {
                cipher_suite: suite.into(),
                content,
                context,
                signature,
                signer: secret_key.to_vec(),
                public: public_key.to_vec(),
            });
        }

        generated
    }

    /// Loads the `signatures` vectors, with `generate_test_cases` as the generator.
    #[cfg(mls_build_async)]
    async fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(signatures, generate_test_cases().await)
    }

    /// Loads the `signatures` vectors, with `generate_test_cases` as the generator.
    #[cfg(not(mls_build_async))]
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(signatures, generate_test_cases())
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_signatures() {
        let vectors = load_test_cases().await;

        for case in vectors {
            let Some(cs) = try_test_cipher_suite_provider(case.cipher_suite) else {
                // Cipher suite not available in this build.
                continue;
            };

            let public_key = SignaturePublicKey::from(case.public);

            // Wasm uses incompatible signature secret key format
            #[cfg(not(target_arch = "wasm32"))]
            {
                // Round trip: sign with the stored secret key, then verify.
                let secret_key = SignatureSecretKey::from(case.signer);

                let mut fresh = TestSignable {
                    content: case.content.clone(),
                    signature: Vec::new(),
                };

                fresh.sign(&cs, &secret_key, &case.context).await.unwrap();
                fresh.verify(&cs, &public_key, &case.context).await.unwrap();
            }

            // The stored signature must also verify.
            let stored = TestSignable {
                content: case.content,
                signature: case.signature,
            };

            stored.verify(&cs, &public_key, &case.context).await.unwrap();
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_invalid_signature() {
        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        let (signing_key, _) = cs.signature_key_generate().await.unwrap();
        let (_, unrelated_public_key) = cs.signature_key_generate().await.unwrap();

        let mut signable = TestSignable {
            content: random_bytes(32),
            signature: vec![],
        };

        signable.sign(&cs, &signing_key, &vec![]).await.unwrap();

        // Verifying against a key from a different pair must be rejected.
        let result = signable.verify(&cs, &unrelated_public_key, &vec![]).await;

        assert_matches!(result, Err(MlsError::InvalidSignature));
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_invalid_context() {
        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        let (secret_key, public_key) = cs.signature_key_generate().await.unwrap();

        let signing_context = random_bytes(32);
        let other_context = random_bytes(32);

        let mut signable = TestSignable {
            content: random_bytes(32),
            signature: vec![],
        };

        signable.sign(&cs, &secret_key, &signing_context).await.unwrap();

        // The context is part of the covered bytes, so a different one must fail.
        let result = signable.verify(&cs, &public_key, &other_context).await;

        assert_matches!(result, Err(MlsError::InvalidSignature));
    }
}
358