1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use super::framing::Content;
6 use crate::client::MlsError;
7 use crate::crypto::SignatureSecretKey;
8 use crate::group::framing::{ContentType, FramedContent, PublicMessage, Sender, WireFormat};
9 use crate::group::{ConfirmationTag, GroupContext};
10 use crate::signer::Signable;
11 use crate::CipherSuiteProvider;
12 use alloc::vec;
13 use alloc::vec::Vec;
14 use core::{
15     fmt::{self, Debug},
16     ops::Deref,
17 };
18 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
19 use mls_rs_core::protocol_version::ProtocolVersion;
20 
/// Authentication data that accompanies a `FramedContent`.
///
/// Holds the message signature and, optionally, a confirmation tag. The
/// decoder in this file only produces a tag for commit content.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FramedContentAuthData {
    /// Signature bytes; empty until the content has been signed.
    pub signature: MessageSignature,
    /// Confirmation tag, `None` when the content carries no tag.
    pub confirmation_tag: Option<ConfirmationTag>,
}
28 
29 impl MlsSize for FramedContentAuthData {
mls_encoded_len(&self) -> usize30     fn mls_encoded_len(&self) -> usize {
31         self.signature.mls_encoded_len()
32             + self
33                 .confirmation_tag
34                 .as_ref()
35                 .map_or(0, |tag| tag.mls_encoded_len())
36     }
37 }
38 
39 impl MlsEncode for FramedContentAuthData {
mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error>40     fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
41         self.signature.mls_encode(writer)?;
42 
43         if let Some(ref tag) = self.confirmation_tag {
44             tag.mls_encode(writer)?;
45         }
46 
47         Ok(())
48     }
49 }
50 
51 impl FramedContentAuthData {
mls_decode( reader: &mut &[u8], content_type: ContentType, ) -> Result<Self, mls_rs_codec::Error>52     pub(crate) fn mls_decode(
53         reader: &mut &[u8],
54         content_type: ContentType,
55     ) -> Result<Self, mls_rs_codec::Error> {
56         Ok(FramedContentAuthData {
57             signature: MessageSignature::mls_decode(reader)?,
58             confirmation_tag: match content_type {
59                 ContentType::Commit => Some(ConfirmationTag::mls_decode(reader)?),
60                 #[cfg(feature = "private_message")]
61                 ContentType::Application => None,
62                 #[cfg(feature = "by_ref_proposal")]
63                 ContentType::Proposal => None,
64             },
65         })
66     }
67 }
68 
/// Framed content bundled with the wire format it is associated with and its
/// authentication data.
///
/// Encoding is derived and therefore follows field order; decoding is
/// implemented by hand in this file because the auth data layout depends on
/// the content type.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct AuthenticatedContent {
    /// Wire format this content is associated with.
    pub(crate) wire_format: WireFormat,
    /// The framed content itself.
    pub(crate) content: FramedContent,
    /// Signature and optional confirmation tag for `content`.
    pub(crate) auth: FramedContentAuthData,
}
76 
77 impl From<PublicMessage> for AuthenticatedContent {
from(p: PublicMessage) -> Self78     fn from(p: PublicMessage) -> Self {
79         Self {
80             wire_format: WireFormat::PublicMessage,
81             content: p.content,
82             auth: p.auth,
83         }
84     }
85 }
86 
87 impl AuthenticatedContent {
new( context: &GroupContext, sender: Sender, content: Content, authenticated_data: Vec<u8>, wire_format: WireFormat, ) -> AuthenticatedContent88     pub(crate) fn new(
89         context: &GroupContext,
90         sender: Sender,
91         content: Content,
92         authenticated_data: Vec<u8>,
93         wire_format: WireFormat,
94     ) -> AuthenticatedContent {
95         AuthenticatedContent {
96             wire_format,
97             content: FramedContent {
98                 group_id: context.group_id.clone(),
99                 epoch: context.epoch,
100                 sender,
101                 authenticated_data,
102                 content,
103             },
104             auth: FramedContentAuthData {
105                 signature: MessageSignature::empty(),
106                 confirmation_tag: None,
107             },
108         }
109     }
110 
111     #[inline(never)]
112     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
new_signed<P: CipherSuiteProvider>( signature_provider: &P, context: &GroupContext, sender: Sender, content: Content, signer: &SignatureSecretKey, wire_format: WireFormat, authenticated_data: Vec<u8>, ) -> Result<AuthenticatedContent, MlsError>113     pub(crate) async fn new_signed<P: CipherSuiteProvider>(
114         signature_provider: &P,
115         context: &GroupContext,
116         sender: Sender,
117         content: Content,
118         signer: &SignatureSecretKey,
119         wire_format: WireFormat,
120         authenticated_data: Vec<u8>,
121     ) -> Result<AuthenticatedContent, MlsError> {
122         // Construct an MlsPlaintext object containing the content
123         let mut plaintext =
124             AuthenticatedContent::new(context, sender, content, authenticated_data, wire_format);
125 
126         let signing_context = MessageSigningContext {
127             group_context: Some(context),
128             protocol_version: context.protocol_version,
129         };
130 
131         // Sign the MlsPlaintext using the current epoch's GroupContext as context.
132         plaintext
133             .sign(signature_provider, signer, &signing_context)
134             .await?;
135 
136         Ok(plaintext)
137     }
138 }
139 
140 impl MlsDecode for AuthenticatedContent {
mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error>141     fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
142         let wire_format = WireFormat::mls_decode(reader)?;
143         let content = FramedContent::mls_decode(reader)?;
144         let auth_data = FramedContentAuthData::mls_decode(reader, content.content_type())?;
145 
146         Ok(AuthenticatedContent {
147             wire_format,
148             content,
149             auth: auth_data,
150         })
151     }
152 }
153 
/// Borrowed view of the data that is serialized to produce the signable
/// content of an `AuthenticatedContent`.
///
/// Fields are encoded in declaration order; `context` is encoded only when it
/// is `Some`.
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct AuthenticatedContentTBS<'a> {
    /// Protocol version placed at the start of the encoding.
    pub(crate) protocol_version: ProtocolVersion,
    /// Wire format of the content being signed.
    pub(crate) wire_format: WireFormat,
    /// The framed content being signed.
    pub(crate) content: &'a FramedContent,
    /// Group context appended to the encoding when present.
    pub(crate) context: Option<&'a GroupContext>,
}
161 
162 impl<'a> MlsSize for AuthenticatedContentTBS<'a> {
mls_encoded_len(&self) -> usize163     fn mls_encoded_len(&self) -> usize {
164         self.protocol_version.mls_encoded_len()
165             + self.wire_format.mls_encoded_len()
166             + self.content.mls_encoded_len()
167             + self.context.as_ref().map_or(0, |ctx| ctx.mls_encoded_len())
168     }
169 }
170 
171 impl<'a> MlsEncode for AuthenticatedContentTBS<'a> {
mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error>172     fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
173         self.protocol_version.mls_encode(writer)?;
174         self.wire_format.mls_encode(writer)?;
175         self.content.mls_encode(writer)?;
176 
177         if let Some(context) = self.context {
178             context.mls_encode(writer)?;
179         }
180 
181         Ok(())
182     }
183 }
184 
185 impl<'a> AuthenticatedContentTBS<'a> {
186     /// The group context must not be `None` when the sender is `Member` or `NewMember`.
from_authenticated_content( auth_content: &'a AuthenticatedContent, group_context: Option<&'a GroupContext>, protocol_version: ProtocolVersion, ) -> Self187     pub(crate) fn from_authenticated_content(
188         auth_content: &'a AuthenticatedContent,
189         group_context: Option<&'a GroupContext>,
190         protocol_version: ProtocolVersion,
191     ) -> Self {
192         AuthenticatedContentTBS {
193             protocol_version,
194             wire_format: auth_content.wire_format,
195             content: &auth_content.content,
196             context: match auth_content.content.sender {
197                 Sender::Member(_) | Sender::NewMemberCommit => group_context,
198                 #[cfg(feature = "by_ref_proposal")]
199                 Sender::External(_) => None,
200                 #[cfg(feature = "by_ref_proposal")]
201                 Sender::NewMemberProposal => None,
202             },
203         }
204     }
205 }
206 
/// Extra inputs needed to build the signable content of an
/// `AuthenticatedContent`.
#[derive(Debug)]
pub(crate) struct MessageSigningContext<'a> {
    /// Group context offered for inclusion in the signed data; whether it is
    /// actually included depends on the content's sender.
    pub group_context: Option<&'a GroupContext>,
    /// Protocol version included in the signed data.
    pub protocol_version: ProtocolVersion,
}
212 
213 impl<'a> Signable<'a> for AuthenticatedContent {
214     const SIGN_LABEL: &'static str = "FramedContentTBS";
215 
216     type SigningContext = MessageSigningContext<'a>;
217 
signature(&self) -> &[u8]218     fn signature(&self) -> &[u8] {
219         &self.auth.signature
220     }
221 
signable_content( &self, context: &MessageSigningContext, ) -> Result<Vec<u8>, mls_rs_codec::Error>222     fn signable_content(
223         &self,
224         context: &MessageSigningContext,
225     ) -> Result<Vec<u8>, mls_rs_codec::Error> {
226         AuthenticatedContentTBS::from_authenticated_content(
227             self,
228             context.group_context,
229             context.protocol_version,
230         )
231         .mls_encode_to_vec()
232     }
233 
write_signature(&mut self, signature: Vec<u8>)234     fn write_signature(&mut self, signature: Vec<u8>) {
235         self.auth.signature = MessageSignature::from(signature)
236     }
237 }
238 
/// Newtype wrapping the raw bytes of a message signature.
///
/// The bytes are encoded and decoded through `mls_rs_codec::byte_vec`.
#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MessageSignature(
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    Vec<u8>,
);
247 
248 impl Debug for MessageSignature {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result249     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
250         mls_rs_core::debug::pretty_bytes(&self.0)
251             .named("MessageSignature")
252             .fmt(f)
253     }
254 }
255 
256 impl MessageSignature {
empty() -> Self257     pub(crate) fn empty() -> Self {
258         MessageSignature(vec![])
259     }
260 }
261 
262 impl Deref for MessageSignature {
263     type Target = Vec<u8>;
264 
deref(&self) -> &Self::Target265     fn deref(&self) -> &Self::Target {
266         &self.0
267     }
268 }
269 
270 impl From<Vec<u8>> for MessageSignature {
from(v: Vec<u8>) -> Self271     fn from(v: Vec<u8>) -> Self {
272         MessageSignature(v)
273     }
274 }
275