1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec::Vec;
6 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
7 
8 use crate::CipherSuiteProvider;
9 
10 const REUSE_GUARD_SIZE: usize = 4;
11 
12 #[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
13 pub(crate) struct ReuseGuard([u8; REUSE_GUARD_SIZE]);
14 
15 impl From<[u8; REUSE_GUARD_SIZE]> for ReuseGuard {
from(value: [u8; REUSE_GUARD_SIZE]) -> Self16     fn from(value: [u8; REUSE_GUARD_SIZE]) -> Self {
17         ReuseGuard(value)
18     }
19 }
20 
21 impl From<ReuseGuard> for [u8; REUSE_GUARD_SIZE] {
from(value: ReuseGuard) -> Self22     fn from(value: ReuseGuard) -> Self {
23         value.0
24     }
25 }
26 
27 impl AsRef<[u8]> for ReuseGuard {
as_ref(&self) -> &[u8]28     fn as_ref(&self) -> &[u8] {
29         &self.0
30     }
31 }
32 
33 impl ReuseGuard {
random<P: CipherSuiteProvider>(provider: &P) -> Result<Self, P::Error>34     pub(crate) fn random<P: CipherSuiteProvider>(provider: &P) -> Result<Self, P::Error> {
35         let mut data = [0u8; REUSE_GUARD_SIZE];
36         provider.random_bytes(&mut data).map(|_| ReuseGuard(data))
37     }
38 
apply(&self, nonce: &[u8]) -> Vec<u8>39     pub(crate) fn apply(&self, nonce: &[u8]) -> Vec<u8> {
40         let mut new_nonce = nonce.to_vec();
41 
42         new_nonce
43             .iter_mut()
44             .zip(self.as_ref().iter())
45             .for_each(|(nonce_byte, guard_byte)| *nonce_byte ^= guard_byte);
46 
47         new_nonce
48     }
49 }
50 
#[cfg(test)]
mod test_utils {
    use alloc::vec::Vec;

    use super::{ReuseGuard, REUSE_GUARD_SIZE};

    impl ReuseGuard {
        /// Builds a guard from an explicit byte vector (test helper).
        ///
        /// # Panics
        /// Panics if `guard` is not exactly `REUSE_GUARD_SIZE` bytes long; the
        /// message names the expected length so misuse in a test is obvious.
        pub fn new(guard: Vec<u8>) -> Self {
            assert_eq!(
                guard.len(),
                REUSE_GUARD_SIZE,
                "reuse guard must be exactly {} bytes",
                REUSE_GUARD_SIZE
            );

            let mut data = [0u8; REUSE_GUARD_SIZE];
            data.copy_from_slice(&guard);
            Self(data)
        }
    }
}
65 
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use mls_rs_core::crypto::CipherSuiteProvider;

    use crate::{
        client::test_utils::TEST_CIPHER_SUITE, crypto::test_utils::test_cipher_suite_provider,
    };

    use super::{ReuseGuard, REUSE_GUARD_SIZE};

    /// Successive guards drawn from the same provider should differ from the first.
    ///
    /// Each guard is `REUSE_GUARD_SIZE` (4) bytes, so assuming uniform provider
    /// output a spurious collision over 1000 draws has probability on the order
    /// of 1000 / 2^32.
    #[test]
    fn test_random_generation() {
        // Build the provider once rather than once per draw.
        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let test_guard = ReuseGuard::random(&provider).unwrap();

        (0..1000).for_each(|_| {
            let next = ReuseGuard::random(&provider).unwrap();
            assert_ne!(next, test_guard);
        })
    }

    /// One known-answer vector: `result` is `nonce` after applying `guard`.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct TestCase {
        nonce: Vec<u8>,
        guard: [u8; REUSE_GUARD_SIZE],
        result: Vec<u8>,
    }

    /// Generates vectors for 16- and 32-byte nonces with random guards.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_reuse_guard_test_cases() -> Vec<TestCase> {
        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        [16, 32]
            .into_iter()
            .map(
                #[cfg_attr(coverage_nightly, coverage(off))]
                |len| {
                    let nonce = provider.random_bytes_vec(len).unwrap();
                    let guard = ReuseGuard::random(&provider).unwrap();

                    let result = guard.apply(&nonce);

                    TestCase {
                        nonce,
                        guard: guard.into(),
                        result,
                    }
                },
            )
            .collect()
    }

    /// Loads the stored `reuse_guard` vectors, generating them if absent.
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(reuse_guard, generate_reuse_guard_test_cases())
    }

    /// Checks `apply` against every stored known-answer vector.
    #[test]
    fn test_reuse_guard() {
        let test_cases = load_test_cases();

        // An empty vector file would otherwise make this test pass vacuously.
        assert!(!test_cases.is_empty());

        for case in test_cases {
            let guard = ReuseGuard::from(case.guard);
            let result = guard.apply(&case.nonce);
            assert_eq!(result, case.result);
        }
    }
}
134