1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use core::{
6     fmt::{self, Debug},
7     ops::Deref,
8 };
9 
10 use alloc::vec::Vec;
11 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
12 
13 use super::BasicCredential;
14 
15 #[cfg(feature = "x509")]
16 use super::CertificateChain;
17 
/// Wrapper type representing a credential type identifier along with default
/// values defined by the MLS RFC (exposed as associated constants).
///
/// The identifier is a raw `u16`. Build one with [`CredentialType::new`] or
/// `From<u16>`, and read it back with [`CredentialType::raw_value`]. When a
/// [`Credential`] is decoded, identifiers that do not match one of the built-in
/// constants enabled in this build are treated as custom credential types.
#[derive(
    Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode,
)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
// `repr(transparent)` gives this wrapper the same layout as its inner `u16`.
#[repr(transparent)]
pub struct CredentialType(u16);
28 
29 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
30 impl CredentialType {
31     /// Basic identity.
32     pub const BASIC: CredentialType = CredentialType(1);
33 
34     #[cfg(feature = "x509")]
35     /// X509 Certificate Identity.
36     pub const X509: CredentialType = CredentialType(2);
37 
new(raw_value: u16) -> Self38     pub const fn new(raw_value: u16) -> Self {
39         CredentialType(raw_value)
40     }
41 
raw_value(&self) -> u1642     pub const fn raw_value(&self) -> u16 {
43         self.0
44     }
45 }
46 
47 impl From<u16> for CredentialType {
from(value: u16) -> Self48     fn from(value: u16) -> Self {
49         CredentialType(value)
50     }
51 }
52 
53 impl Deref for CredentialType {
54     type Target = u16;
55 
deref(&self) -> &Self::Target56     fn deref(&self) -> &Self::Target {
57         &self.0
58     }
59 }
60 
#[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
/// Custom, user-created credential.
///
/// Pairs an application-chosen [`CredentialType`] with a payload that this
/// module treats as opaque bytes.
///
/// # Warning
///
/// In order to use a custom credential within an MLS group, a supporting
/// [`IdentityProvider`](crate::identity::IdentityProvider) must be created that can
/// authenticate the credential.
pub struct CustomCredential {
    /// Unique credential type to identify this custom credential.
    pub credential_type: CredentialType,
    /// Opaque data representing this custom credential.
    // The MLS codec uses `mls_rs_codec::byte_vec` for this field and serde
    // uses `crate::vec_serde`, instead of the default `Vec<u8>` handling.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "crate::vec_serde"))]
    pub data: Vec<u8>,
}
83 
84 impl Debug for CustomCredential {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result85     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86         f.debug_struct("CustomCredential")
87             .field("credential_type", &self.credential_type)
88             .field("data", &crate::debug::pretty_bytes(&self.data))
89             .finish()
90     }
91 }
92 
93 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
94 impl CustomCredential {
95     /// Create a new custom credential with opaque data.
96     ///
97     /// # Warning
98     ///
99     /// Using any of the constants defined within [`CredentialType`] will
100     /// result in unspecified behavior.
new(credential_type: CredentialType, data: Vec<u8>) -> CustomCredential101     pub fn new(credential_type: CredentialType, data: Vec<u8>) -> CustomCredential {
102         CustomCredential {
103             credential_type,
104             data,
105         }
106     }
107 
108     /// Unique credential type to identify this custom credential.
109     #[cfg(feature = "ffi")]
credential_type(&self) -> CredentialType110     pub fn credential_type(&self) -> CredentialType {
111         self.credential_type
112     }
113 
114     /// Opaque data representing this custom credential.
115     #[cfg(feature = "ffi")]
data(&self) -> &[u8]116     pub fn data(&self) -> &[u8] {
117         &self.data
118     }
119 }
120 
/// An MLS credential used to authenticate a group member.
///
/// On the wire a credential is encoded as its [`CredentialType`] followed by a
/// variant-specific body (see the `MlsEncode` / `MlsDecode` impls in this
/// module). When decoding, any credential type that is not a built-in type
/// enabled in this build becomes [`Credential::Custom`].
#[derive(Clone, Debug, PartialEq, Ord, PartialOrd, Eq, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Credential {
    /// Basic identifier-only credential.
    ///
    /// # Warning
    ///
    /// Basic credentials are inherently insecure since they cannot be
    /// properly validated. It is not recommended to use [`BasicCredential`]
    /// in production applications.
    Basic(BasicCredential),
    #[cfg(feature = "x509")]
    /// X.509 Certificate chain.
    X509(CertificateChain),
    /// User provided custom credential.
    Custom(CustomCredential),
}
145 
146 impl Credential {
147     /// Credential type of the underlying credential.
credential_type(&self) -> CredentialType148     pub fn credential_type(&self) -> CredentialType {
149         match self {
150             Credential::Basic(_) => CredentialType::BASIC,
151             #[cfg(feature = "x509")]
152             Credential::X509(_) => CredentialType::X509,
153             Credential::Custom(c) => c.credential_type,
154         }
155     }
156 
157     /// Convert this enum into a [`BasicCredential`]
158     ///
159     /// Returns `None` if this credential is any other type.
as_basic(&self) -> Option<&BasicCredential>160     pub fn as_basic(&self) -> Option<&BasicCredential> {
161         match self {
162             Credential::Basic(basic) => Some(basic),
163             _ => None,
164         }
165     }
166 
167     /// Convert this enum into a [`CertificateChain`]
168     ///
169     /// Returns `None` if this credential is any other type.
170     #[cfg(feature = "x509")]
as_x509(&self) -> Option<&CertificateChain>171     pub fn as_x509(&self) -> Option<&CertificateChain> {
172         match self {
173             Credential::X509(chain) => Some(chain),
174             _ => None,
175         }
176     }
177 
178     /// Convert this enum into a [`CustomCredential`]
179     ///
180     /// Returns `None` if this credential is any other type.
as_custom(&self) -> Option<&CustomCredential>181     pub fn as_custom(&self) -> Option<&CustomCredential> {
182         match self {
183             Credential::Custom(custom) => Some(custom),
184             _ => None,
185         }
186     }
187 }
188 
189 impl MlsSize for Credential {
mls_encoded_len(&self) -> usize190     fn mls_encoded_len(&self) -> usize {
191         let inner_len = match self {
192             Credential::Basic(c) => c.mls_encoded_len(),
193             #[cfg(feature = "x509")]
194             Credential::X509(c) => c.mls_encoded_len(),
195             Credential::Custom(c) => mls_rs_codec::byte_vec::mls_encoded_len(&c.data),
196         };
197 
198         self.credential_type().mls_encoded_len() + inner_len
199     }
200 }
201 
202 impl MlsEncode for Credential {
mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error>203     fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
204         self.credential_type().mls_encode(writer)?;
205 
206         match self {
207             Credential::Basic(c) => c.mls_encode(writer),
208             #[cfg(feature = "x509")]
209             Credential::X509(c) => c.mls_encode(writer),
210             Credential::Custom(c) => mls_rs_codec::byte_vec::mls_encode(&c.data, writer),
211         }
212     }
213 }
214 
215 impl MlsDecode for Credential {
mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error>216     fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
217         let credential_type = CredentialType::mls_decode(reader)?;
218 
219         Ok(match credential_type {
220             CredentialType::BASIC => Credential::Basic(BasicCredential::mls_decode(reader)?),
221             #[cfg(feature = "x509")]
222             CredentialType::X509 => Credential::X509(CertificateChain::mls_decode(reader)?),
223             custom => Credential::Custom(CustomCredential {
224                 credential_type: custom,
225                 data: mls_rs_codec::byte_vec::mls_decode(reader)?,
226             }),
227         })
228     }
229 }
230 
/// Trait for types that can be converted into the [`Credential`] enum.
///
/// Implementors declare which [`CredentialType`] they represent and how to
/// wrap themselves in a [`Credential`].
pub trait MlsCredential: Sized {
    /// Error returned when [`MlsCredential::into_credential`] fails.
    type Error;

    /// Credential type represented by this type.
    ///
    /// NOTE(review): presumably this should equal `credential_type()` of the
    /// [`Credential`] produced by `into_credential`; nothing here enforces it.
    fn credential_type() -> CredentialType;

    /// Consume `self` and convert it into a [`Credential`] enum.
    ///
    /// Returns `Self::Error` if the conversion fails.
    fn into_credential(self) -> Result<Credential, Self::Error>;
}
243