1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use alloc::vec::Vec; 6 use core::fmt::{self, Debug}; 7 use mls_rs_codec::{MlsEncode, MlsSize}; 8 use mls_rs_core::{ 9 crypto::{CipherSuiteProvider, HpkeCiphertext, HpkePublicKey, HpkeSecretKey}, 10 error::IntoAnyError, 11 }; 12 use zeroize::Zeroizing; 13 14 use crate::client::MlsError; 15 16 #[derive(Clone, MlsSize, MlsEncode)] 17 struct EncryptContext<'a> { 18 #[mls_codec(with = "mls_rs_codec::byte_vec")] 19 label: Vec<u8>, 20 #[mls_codec(with = "mls_rs_codec::byte_vec")] 21 context: &'a [u8], 22 } 23 24 impl Debug for EncryptContext<'_> { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result25 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 26 f.debug_struct("EncryptContext") 27 .field("label", &mls_rs_core::debug::pretty_bytes(&self.label)) 28 .field("context", &mls_rs_core::debug::pretty_bytes(self.context)) 29 .finish() 30 } 31 } 32 33 impl<'a> EncryptContext<'a> { new(label: &str, context: &'a [u8]) -> Self34 pub fn new(label: &str, context: &'a [u8]) -> Self { 35 Self { 36 label: [b"MLS 1.0 ", label.as_bytes()].concat(), 37 context, 38 } 39 } 40 } 41 42 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] 43 #[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))] 44 #[cfg_attr( 45 all(not(target_arch = "wasm32"), mls_build_async), 46 maybe_async::must_be_async 47 )] 48 49 pub(crate) trait HpkeEncryptable: Sized { 50 const ENCRYPT_LABEL: &'static str; 51 encrypt<P: CipherSuiteProvider>( &self, cipher_suite_provider: &P, public_key: &HpkePublicKey, context: &[u8], ) -> Result<HpkeCiphertext, MlsError>52 async fn encrypt<P: CipherSuiteProvider>( 53 &self, 54 cipher_suite_provider: &P, 55 public_key: &HpkePublicKey, 56 context: &[u8], 57 ) -> Result<HpkeCiphertext, MlsError> { 58 let context = EncryptContext::new(Self::ENCRYPT_LABEL, context) 59 .mls_encode_to_vec() 60 .map(Zeroizing::new)?; 61 62 let content = self.get_bytes().map(Zeroizing::new)?; 63 64 cipher_suite_provider 65 .hpke_seal(public_key, &context, None, &content) 66 .await 67 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error())) 68 } 69 decrypt<P: CipherSuiteProvider>( cipher_suite_provider: &P, secret_key: &HpkeSecretKey, public_key: &HpkePublicKey, context: &[u8], ciphertext: &HpkeCiphertext, ) -> Result<Self, MlsError>70 async fn decrypt<P: CipherSuiteProvider>( 71 cipher_suite_provider: &P, 72 secret_key: &HpkeSecretKey, 73 public_key: &HpkePublicKey, 74 context: &[u8], 75 ciphertext: &HpkeCiphertext, 76 ) -> Result<Self, MlsError> { 77 let context = EncryptContext::new(Self::ENCRYPT_LABEL, context).mls_encode_to_vec()?; 78 79 let plaintext = cipher_suite_provider 80 .hpke_open(ciphertext, secret_key, public_key, &context, None) 81 .await 82 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?; 83 84 Self::from_bytes(plaintext.to_vec()) 85 } 86 from_bytes(bytes: Vec<u8>) -> Result<Self, MlsError>87 fn from_bytes(bytes: Vec<u8>) -> Result<Self, MlsError>; get_bytes(&self) -> Result<Vec<u8>, MlsError>88 fn get_bytes(&self) -> Result<Vec<u8>, MlsError>; 89 } 90 91 #[cfg(test)] 92 pub(crate) mod test_utils { 93 use alloc::{string::String, vec::Vec}; 94 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; 95 use mls_rs_core::crypto::{CipherSuiteProvider, HpkeCiphertext}; 96 97 use crate::{client::MlsError, crypto::test_utils::try_test_cipher_suite_provider}; 98 99 use super::HpkeEncryptable; 100 101 #[derive(Debug, serde::Serialize, serde::Deserialize)] 102 pub struct HpkeInteropTestCase { 103 #[serde(with = "hex::serde", rename = "priv")] 104 secret: Vec<u8>, 105 #[serde(with = "hex::serde", rename = "pub")] 106 public: Vec<u8>, 107 label: String, 108 #[serde(with = "hex::serde")] 109 context: Vec<u8>, 110 #[serde(with = "hex::serde")] 111 plaintext: Vec<u8>, 112 #[serde(with = "hex::serde")] 113 kem_output: Vec<u8>, 114 #[serde(with = "hex::serde")] 115 ciphertext: Vec<u8>, 116 } 117 118 #[derive(Debug, serde::Serialize, serde::Deserialize)] 119 pub struct InteropTestCase { 120 cipher_suite: u16, 121 encrypt_with_label: HpkeInteropTestCase, 122 } 123 124 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_basic_crypto_test_vectors()125 async fn test_basic_crypto_test_vectors() { 126 // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/crypto-basics.json 127 let test_cases: Vec<InteropTestCase> = 128 load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new()); 129 130 for test_case in test_cases { 131 if let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) { 132 test_case.encrypt_with_label.verify(&cs).await 133 } 134 } 135 } 136 137 #[derive(Clone, Debug, MlsSize, MlsEncode, MlsDecode)] 138 struct TestEncryptable(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>); 139 140 impl HpkeEncryptable for TestEncryptable { 141 const ENCRYPT_LABEL: &'static str = "EncryptWithLabel"; 142 from_bytes(bytes: Vec<u8>) -> Result<Self, MlsError>143 fn from_bytes(bytes: Vec<u8>) -> Result<Self, MlsError> { 144 Ok(Self(bytes)) 145 } 146 147 #[cfg_attr(coverage_nightly, coverage(off))] get_bytes(&self) -> Result<Vec<u8>, MlsError>148 fn get_bytes(&self) -> Result<Vec<u8>, MlsError> { 149 Ok(self.0.clone()) 150 } 151 } 152 153 impl HpkeInteropTestCase { 154 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] verify<P: CipherSuiteProvider>(&self, cs: &P)155 pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) { 156 let secret = self.secret.clone().into(); 157 let public = self.public.clone().into(); 158 159 let ciphertext = HpkeCiphertext { 160 kem_output: self.kem_output.clone(), 161 ciphertext: self.ciphertext.clone(), 162 }; 163 164 let computed_plaintext = 165 TestEncryptable::decrypt(cs, &secret, &public, &self.context, &ciphertext) 166 .await 167 .unwrap(); 168 169 assert_eq!(&computed_plaintext.0, &self.plaintext) 170 } 171 } 172 } 173