1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use super::framing::Content; 6 use crate::client::MlsError; 7 use crate::crypto::SignatureSecretKey; 8 use crate::group::framing::{ContentType, FramedContent, PublicMessage, Sender, WireFormat}; 9 use crate::group::{ConfirmationTag, GroupContext}; 10 use crate::signer::Signable; 11 use crate::CipherSuiteProvider; 12 use alloc::vec; 13 use alloc::vec::Vec; 14 use core::{ 15 fmt::{self, Debug}, 16 ops::Deref, 17 }; 18 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; 19 use mls_rs_core::protocol_version::ProtocolVersion; 20 21 #[derive(Clone, Debug, PartialEq)] 22 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 23 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 24 pub struct FramedContentAuthData { 25 pub signature: MessageSignature, 26 pub confirmation_tag: Option<ConfirmationTag>, 27 } 28 29 impl MlsSize for FramedContentAuthData { mls_encoded_len(&self) -> usize30 fn mls_encoded_len(&self) -> usize { 31 self.signature.mls_encoded_len() 32 + self 33 .confirmation_tag 34 .as_ref() 35 .map_or(0, |tag| tag.mls_encoded_len()) 36 } 37 } 38 39 impl MlsEncode for FramedContentAuthData { mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error>40 fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> { 41 self.signature.mls_encode(writer)?; 42 43 if let Some(ref tag) = self.confirmation_tag { 44 tag.mls_encode(writer)?; 45 } 46 47 Ok(()) 48 } 49 } 50 51 impl FramedContentAuthData { mls_decode( reader: &mut &[u8], content_type: ContentType, ) -> Result<Self, mls_rs_codec::Error>52 pub(crate) fn mls_decode( 53 reader: &mut &[u8], 54 content_type: ContentType, 55 ) -> Result<Self, mls_rs_codec::Error> { 56 Ok(FramedContentAuthData { 57 signature: MessageSignature::mls_decode(reader)?, 58 confirmation_tag: match content_type { 59 ContentType::Commit => Some(ConfirmationTag::mls_decode(reader)?), 60 #[cfg(feature = "private_message")] 61 ContentType::Application => None, 62 #[cfg(feature = "by_ref_proposal")] 63 ContentType::Proposal => None, 64 }, 65 }) 66 } 67 } 68 69 #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode)] 70 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 71 pub struct AuthenticatedContent { 72 pub(crate) wire_format: WireFormat, 73 pub(crate) content: FramedContent, 74 pub(crate) auth: FramedContentAuthData, 75 } 76 77 impl From<PublicMessage> for AuthenticatedContent { from(p: PublicMessage) -> Self78 fn from(p: PublicMessage) -> Self { 79 Self { 80 wire_format: WireFormat::PublicMessage, 81 content: p.content, 82 auth: p.auth, 83 } 84 } 85 } 86 87 impl AuthenticatedContent { new( context: &GroupContext, sender: Sender, content: Content, authenticated_data: Vec<u8>, wire_format: WireFormat, ) -> AuthenticatedContent88 pub(crate) fn new( 89 context: &GroupContext, 90 sender: Sender, 91 content: Content, 92 authenticated_data: Vec<u8>, 93 wire_format: WireFormat, 94 ) -> AuthenticatedContent { 95 AuthenticatedContent { 96 wire_format, 97 content: FramedContent { 98 group_id: context.group_id.clone(), 99 epoch: context.epoch, 100 sender, 101 authenticated_data, 102 content, 103 }, 104 auth: FramedContentAuthData { 105 signature: MessageSignature::empty(), 106 confirmation_tag: None, 107 }, 108 } 109 } 110 111 #[inline(never)] 112 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] new_signed<P: CipherSuiteProvider>( signature_provider: &P, context: &GroupContext, sender: Sender, content: Content, signer: &SignatureSecretKey, wire_format: WireFormat, authenticated_data: Vec<u8>, ) -> Result<AuthenticatedContent, MlsError>113 pub(crate) async fn new_signed<P: CipherSuiteProvider>( 114 signature_provider: &P, 115 context: &GroupContext, 116 sender: Sender, 117 content: Content, 118 signer: &SignatureSecretKey, 119 wire_format: WireFormat, 120 authenticated_data: Vec<u8>, 121 ) -> Result<AuthenticatedContent, MlsError> { 122 // Construct an MlsPlaintext object containing the content 123 let mut plaintext = 124 AuthenticatedContent::new(context, sender, content, authenticated_data, wire_format); 125 126 let signing_context = MessageSigningContext { 127 group_context: Some(context), 128 protocol_version: context.protocol_version, 129 }; 130 131 // Sign the MlsPlaintext using the current epoch's GroupContext as context. 132 plaintext 133 .sign(signature_provider, signer, &signing_context) 134 .await?; 135 136 Ok(plaintext) 137 } 138 } 139 140 impl MlsDecode for AuthenticatedContent { mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error>141 fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> { 142 let wire_format = WireFormat::mls_decode(reader)?; 143 let content = FramedContent::mls_decode(reader)?; 144 let auth_data = FramedContentAuthData::mls_decode(reader, content.content_type())?; 145 146 Ok(AuthenticatedContent { 147 wire_format, 148 content, 149 auth: auth_data, 150 }) 151 } 152 } 153 154 #[derive(Clone, Debug, PartialEq)] 155 pub(crate) struct AuthenticatedContentTBS<'a> { 156 pub(crate) protocol_version: ProtocolVersion, 157 pub(crate) wire_format: WireFormat, 158 pub(crate) content: &'a FramedContent, 159 pub(crate) context: Option<&'a GroupContext>, 160 } 161 162 impl<'a> MlsSize for AuthenticatedContentTBS<'a> { mls_encoded_len(&self) -> usize163 fn mls_encoded_len(&self) -> usize { 164 self.protocol_version.mls_encoded_len() 165 + self.wire_format.mls_encoded_len() 166 + self.content.mls_encoded_len() 167 + self.context.as_ref().map_or(0, |ctx| ctx.mls_encoded_len()) 168 } 169 } 170 171 impl<'a> MlsEncode for AuthenticatedContentTBS<'a> { mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error>172 fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> { 173 self.protocol_version.mls_encode(writer)?; 174 self.wire_format.mls_encode(writer)?; 175 self.content.mls_encode(writer)?; 176 177 if let Some(context) = self.context { 178 context.mls_encode(writer)?; 179 } 180 181 Ok(()) 182 } 183 } 184 185 impl<'a> AuthenticatedContentTBS<'a> { 186 /// The group context must not be `None` when the sender is `Member` or `NewMember`. from_authenticated_content( auth_content: &'a AuthenticatedContent, group_context: Option<&'a GroupContext>, protocol_version: ProtocolVersion, ) -> Self187 pub(crate) fn from_authenticated_content( 188 auth_content: &'a AuthenticatedContent, 189 group_context: Option<&'a GroupContext>, 190 protocol_version: ProtocolVersion, 191 ) -> Self { 192 AuthenticatedContentTBS { 193 protocol_version, 194 wire_format: auth_content.wire_format, 195 content: &auth_content.content, 196 context: match auth_content.content.sender { 197 Sender::Member(_) | Sender::NewMemberCommit => group_context, 198 #[cfg(feature = "by_ref_proposal")] 199 Sender::External(_) => None, 200 #[cfg(feature = "by_ref_proposal")] 201 Sender::NewMemberProposal => None, 202 }, 203 } 204 } 205 } 206 207 #[derive(Debug)] 208 pub(crate) struct MessageSigningContext<'a> { 209 pub group_context: Option<&'a GroupContext>, 210 pub protocol_version: ProtocolVersion, 211 } 212 213 impl<'a> Signable<'a> for AuthenticatedContent { 214 const SIGN_LABEL: &'static str = "FramedContentTBS"; 215 216 type SigningContext = MessageSigningContext<'a>; 217 signature(&self) -> &[u8]218 fn signature(&self) -> &[u8] { 219 &self.auth.signature 220 } 221 signable_content( &self, context: &MessageSigningContext, ) -> Result<Vec<u8>, mls_rs_codec::Error>222 fn signable_content( 223 &self, 224 context: &MessageSigningContext, 225 ) -> Result<Vec<u8>, mls_rs_codec::Error> { 226 AuthenticatedContentTBS::from_authenticated_content( 227 self, 228 context.group_context, 229 context.protocol_version, 230 ) 231 .mls_encode_to_vec() 232 } 233 write_signature(&mut self, signature: Vec<u8>)234 fn write_signature(&mut self, signature: Vec<u8>) { 235 self.auth.signature = MessageSignature::from(signature) 236 } 237 } 238 239 #[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)] 240 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 241 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 242 pub struct MessageSignature( 243 #[mls_codec(with = "mls_rs_codec::byte_vec")] 244 #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))] 245 Vec<u8>, 246 ); 247 248 impl Debug for MessageSignature { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result249 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 250 mls_rs_core::debug::pretty_bytes(&self.0) 251 .named("MessageSignature") 252 .fmt(f) 253 } 254 } 255 256 impl MessageSignature { empty() -> Self257 pub(crate) fn empty() -> Self { 258 MessageSignature(vec![]) 259 } 260 } 261 262 impl Deref for MessageSignature { 263 type Target = Vec<u8>; 264 deref(&self) -> &Self::Target265 fn deref(&self) -> &Self::Target { 266 &self.0 267 } 268 } 269 270 impl From<Vec<u8>> for MessageSignature { from(v: Vec<u8>) -> Self271 fn from(v: Vec<u8>) -> Self { 272 MessageSignature(v) 273 } 274 } 275