1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 #[cfg(feature = "psk")]
6 use crate::psk::PreSharedKey;
7 #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
8 use crate::tree_kem::node::NodeIndex;
9 #[cfg(feature = "prior_epoch")]
10 use crate::{crypto::SignaturePublicKey, group::GroupContext, tree_kem::node::LeafIndex};
11 use alloc::vec::Vec;
12 use core::{
13     fmt::{self, Debug},
14     ops::Deref,
15 };
16 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
17 use zeroize::Zeroizing;
18 
19 #[cfg(all(feature = "prior_epoch", feature = "private_message"))]
20 use super::ciphertext_processor::GroupStateProvider;
21 
22 #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
23 use crate::group::secret_tree::SecretTree;
24 
/// Retained state of an earlier epoch of a group.
///
/// Holds the epoch's group context, this member's own leaf index, the
/// epoch's secrets, and a list of optional signature public keys.
///
/// NOTE: field order is significant — it defines the MLS codec (and serde)
/// encoding order, so fields must not be reordered.
#[cfg(feature = "prior_epoch")]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
pub(crate) struct PriorEpoch {
    /// Group context of the retained epoch (carries its epoch number and group id).
    pub(crate) context: GroupContext,
    /// Leaf index of the local member in that epoch.
    pub(crate) self_index: LeafIndex,
    /// Secrets derived for that epoch.
    pub(crate) secrets: EpochSecrets,
    /// Optional signature public keys; presumably one entry per leaf — TODO confirm.
    pub(crate) signature_public_keys: Vec<Option<SignaturePublicKey>>,
}
34 
#[cfg(feature = "prior_epoch")]
impl PriorEpoch {
    /// Returns the epoch number stored in this epoch's group context.
    #[inline(always)]
    pub(crate) fn epoch_id(&self) -> u64 {
        let PriorEpoch { context, .. } = self;
        context.epoch
    }

    /// Returns the group identifier stored in this epoch's group context.
    #[inline(always)]
    pub(crate) fn group_id(&self) -> &[u8] {
        let PriorEpoch { context, .. } = self;
        &context.group_id
    }
}
47 
/// Lets a retained [`PriorEpoch`] act as the group state source for
/// ciphertext processing, by exposing its stored context, leaf index and
/// secrets.
#[cfg(all(feature = "private_message", feature = "prior_epoch"))]
impl GroupStateProvider for PriorEpoch {
    /// Borrows the stored group context.
    fn group_context(&self) -> &GroupContext {
        let Self { context, .. } = self;
        context
    }

    /// Returns (a copy of) the stored local leaf index.
    fn self_index(&self) -> LeafIndex {
        let Self { self_index, .. } = self;
        *self_index
    }

    /// Mutably borrows the stored epoch secrets.
    fn epoch_secrets_mut(&mut self) -> &mut EpochSecrets {
        let Self { secrets, .. } = self;
        secrets
    }

    /// Immutably borrows the stored epoch secrets.
    fn epoch_secrets(&self) -> &EpochSecrets {
        let Self { secrets, .. } = self;
        secrets
    }
}
66 
67 #[derive(Debug, Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)]
68 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
69 pub(crate) struct EpochSecrets {
70     #[cfg(feature = "psk")]
71     #[mls_codec(with = "mls_rs_codec::byte_vec")]
72     pub(crate) resumption_secret: PreSharedKey,
73     #[mls_codec(with = "mls_rs_codec::byte_vec")]
74     pub(crate) sender_data_secret: SenderDataSecret,
75     #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
76     pub(crate) secret_tree: SecretTree<NodeIndex>,
77 }
78 
79 #[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)]
80 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
81 pub(crate) struct SenderDataSecret(
82     #[mls_codec(with = "mls_rs_codec::byte_vec")]
83     #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
84     Zeroizing<Vec<u8>>,
85 );
86 
87 impl Debug for SenderDataSecret {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result88     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89         mls_rs_core::debug::pretty_bytes(&self.0)
90             .named("SenderDataSecret")
91             .fmt(f)
92     }
93 }
94 
95 impl AsRef<[u8]> for SenderDataSecret {
as_ref(&self) -> &[u8]96     fn as_ref(&self) -> &[u8] {
97         &self.0
98     }
99 }
100 
101 impl Deref for SenderDataSecret {
102     type Target = Vec<u8>;
103 
deref(&self) -> &Self::Target104     fn deref(&self) -> &Self::Target {
105         &self.0
106     }
107 }
108 
109 impl From<Vec<u8>> for SenderDataSecret {
from(bytes: Vec<u8>) -> Self110     fn from(bytes: Vec<u8>) -> Self {
111         Self(Zeroizing::new(bytes))
112     }
113 }
114 
115 impl From<Zeroizing<Vec<u8>>> for SenderDataSecret {
from(bytes: Zeroizing<Vec<u8>>) -> Self116     fn from(bytes: Zeroizing<Vec<u8>>) -> Self {
117         Self(bytes)
118     }
119 }
120 
/// Test-only constructors for epoch secrets and prior epochs.
#[cfg(test)]
pub(crate) mod test_utils {
    use mls_rs_core::crypto::CipherSuiteProvider;

    use super::*;
    use crate::cipher_suite::CipherSuite;
    use crate::crypto::test_utils::test_cipher_suite_provider;
    use crate::group::test_utils::random_bytes;

    #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
    use crate::group::secret_tree::test_utils::get_test_tree;

    #[cfg(feature = "prior_epoch")]
    use crate::group::test_utils::get_test_group_context_with_id;

    /// Builds an [`EpochSecrets`] whose secrets are random byte strings.
    ///
    /// Each random secret has length `kdf_extract_size()` of the test
    /// provider for `cipher_suite`. When enabled, the secret tree is built by
    /// `get_test_tree` from one such random secret and the argument `2`.
    pub(crate) fn get_test_epoch_secrets(cipher_suite: CipherSuite) -> EpochSecrets {
        let secret_len = test_cipher_suite_provider(cipher_suite).kdf_extract_size();

        // Drawn first, matching the original order of random generation.
        #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
        let secret_tree = get_test_tree(random_bytes(secret_len), 2);

        EpochSecrets {
            #[cfg(feature = "psk")]
            resumption_secret: random_bytes(secret_len).into(),
            sender_data_secret: random_bytes(secret_len).into(),
            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
            secret_tree,
        }
    }

    /// Builds a [`PriorEpoch`] for `group_id` at epoch `id`.
    ///
    /// The epoch has local leaf index 0, random secrets from
    /// [`get_test_epoch_secrets`], and no signature public keys.
    #[cfg(feature = "prior_epoch")]
    pub(crate) fn get_test_epoch_with_id(
        group_id: Vec<u8>,
        cipher_suite: CipherSuite,
        id: u64,
    ) -> PriorEpoch {
        let context = get_test_group_context_with_id(group_id, id, cipher_suite);
        let secrets = get_test_epoch_secrets(cipher_suite);

        PriorEpoch {
            context,
            self_index: LeafIndex(0),
            secrets,
            signature_public_keys: Vec::new(),
        }
    }
}
166