1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use mls_rs_core::{
6     crypto::{CipherSuite, HpkePublicKey, HpkeSecretKey},
7     error::IntoAnyError,
8 };
9 
10 use alloc::vec::Vec;
11 
12 #[cfg(feature = "mock")]
13 use mockall::automock;
14 
/// A trait that provides the key encapsulation mechanism (KEM) operations required by HPKE.
///
/// The `maybe_async` attributes below select the method flavour at build time. Without the
/// `mls_build_async` cfg the trait is made synchronous (`must_be_sync`). With it the methods
/// stay `async`, and on `wasm32` the futures are not required to be `Send` (`?Send`).
/// Under the `mock` feature, `mockall` generates a mock whose `Error` is
/// `crate::mock::TestError`.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
#[cfg_attr(
    all(not(target_arch = "wasm32"), mls_build_async),
    maybe_async::must_be_async
)]
#[cfg_attr(feature = "mock", automock(type Error = crate::mock::TestError;))]
pub trait KemType: Send + Sync {
    /// Error type returned by every fallible method of this trait.
    type Error: IntoAnyError + Send + Sync;

    /// KEM Id, as specified in RFC 9180, Section 5.1 and Table 2.
    fn kem_id(&self) -> u16;

    /// Derives a key pair from the input keying material `ikm`, returned as
    /// `(secret key, public key)`.
    ///
    /// NOTE(review): presumably RFC 9180 `DeriveKeyPair`, i.e. deterministic in `ikm` —
    /// confirm against implementations.
    async fn derive(&self, ikm: &[u8]) -> Result<(HpkeSecretKey, HpkePublicKey), Self::Error>;
    /// Generates a new key pair, returned as `(secret key, public key)`.
    ///
    /// Unlike `derive`, the caller supplies no keying material.
    async fn generate(&self) -> Result<(HpkeSecretKey, HpkePublicKey), Self::Error>;
    /// Validates `key` as a public key for this KEM; `Ok(())` means the key was accepted.
    ///
    /// NOTE(review): the exact checks performed are left to the implementation.
    fn public_key_validate(&self, key: &HpkePublicKey) -> Result<(), Self::Error>;

    /// Encapsulates a shared secret to `remote_key` (presumably the recipient's public key).
    ///
    /// Returns both the shared secret and the ciphertext `enc` that encapsulates it, as a
    /// [`KemResult`].
    async fn encap(&self, remote_key: &HpkePublicKey) -> Result<KemResult, Self::Error>;

    /// Recovers the shared secret from the encapsulation `enc` using `secret_key`, and
    /// returns the shared-secret bytes.
    ///
    /// `local_public` is presumably the public key paired with `secret_key`. RFC 9180 DHKEM
    /// `Decap` mixes `pkR` into the KEM context. Confirm this against implementations.
    async fn decap(
        &self,
        enc: &[u8],
        secret_key: &HpkeSecretKey,
        local_public: &HpkePublicKey,
    ) -> Result<Vec<u8>, Self::Error>;
}
42 
/// Output of the KEM [encap](KemType::encap) operation.
///
/// Holds a shared secret together with the ciphertext that encapsulates it.
pub struct KemResult {
    /// The shared secret established by encapsulation.
    pub shared_secret: Vec<u8>,
    /// The ciphertext encapsulating `shared_secret`.
    pub enc: Vec<u8>,
}

impl KemResult {
    /// Bundles a shared secret and its encapsulation into a `KemResult`.
    pub fn new(shared_secret: Vec<u8>, enc: Vec<u8>) -> Self {
        KemResult { enc, shared_secret }
    }

    /// Borrows the shared secret bytes.
    pub fn shared_secret(&self) -> &[u8] {
        self.shared_secret.as_slice()
    }

    /// Borrows the ciphertext encapsulating the shared secret.
    pub fn enc(&self) -> &[u8] {
        self.enc.as_slice()
    }
}
63 
64 /// Kem identifiers for HPKE
65 #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
66 #[repr(u16)]
67 #[non_exhaustive]
68 pub enum KemId {
69     DhKemP256Sha256 = 0x0010,
70     DhKemP384Sha384 = 0x0011,
71     DhKemP521Sha512 = 0x0012,
72     DhKemX25519Sha256 = 0x0020,
73     DhKemX448Sha512 = 0x0021,
74 }
75 
76 impl KemId {
new(cipher_suite: CipherSuite) -> Option<Self>77     pub fn new(cipher_suite: CipherSuite) -> Option<Self> {
78         match cipher_suite {
79             CipherSuite::CURVE25519_AES128 | CipherSuite::CURVE25519_CHACHA => {
80                 Some(KemId::DhKemX25519Sha256)
81             }
82             CipherSuite::P256_AES128 => Some(KemId::DhKemP256Sha256),
83             CipherSuite::CURVE448_AES256 | CipherSuite::CURVE448_CHACHA => {
84                 Some(KemId::DhKemX448Sha512)
85             }
86             CipherSuite::P384_AES256 => Some(KemId::DhKemP384Sha384),
87             CipherSuite::P521_AES256 => Some(KemId::DhKemP521Sha512),
88             _ => None,
89         }
90     }
91 
n_secret(&self) -> usize92     pub fn n_secret(&self) -> usize {
93         match self {
94             KemId::DhKemP256Sha256 => 32,
95             KemId::DhKemP384Sha384 => 48,
96             KemId::DhKemP521Sha512 => 64,
97             KemId::DhKemX25519Sha256 => 32,
98             KemId::DhKemX448Sha512 => 64,
99         }
100     }
101 }
102