1*60b67249SAndroid Build Coastguard Worker // Copyright 2024 Google LLC
2*60b67249SAndroid Build Coastguard Worker //
3*60b67249SAndroid Build Coastguard Worker // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4*60b67249SAndroid Build Coastguard Worker // use this file except in compliance with the License. You may obtain a copy of
5*60b67249SAndroid Build Coastguard Worker // the License at
6*60b67249SAndroid Build Coastguard Worker //
7*60b67249SAndroid Build Coastguard Worker // https://www.apache.org/licenses/LICENSE-2.0
8*60b67249SAndroid Build Coastguard Worker //
9*60b67249SAndroid Build Coastguard Worker // Unless required by applicable law or agreed to in writing, software
10*60b67249SAndroid Build Coastguard Worker // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11*60b67249SAndroid Build Coastguard Worker // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12*60b67249SAndroid Build Coastguard Worker // License for the specific language governing permissions and limitations under
13*60b67249SAndroid Build Coastguard Worker // the License.
14*60b67249SAndroid Build Coastguard Worker
//! An encrypted session implementation which uses
//! Noise_NK_X25519_AESGCM_SHA512 (for new sessions) and
//! Noise_NNpsk0_X25519_AESGCM_SHA512 (for sessions derived from a PSK). The
//! concrete DH, cipher, and hash primitives are supplied through
//! [`NoiseCryptoDeps`].
17*60b67249SAndroid Build Coastguard Worker
18*60b67249SAndroid Build Coastguard Worker use crate::crypto::{
19*60b67249SAndroid Build Coastguard Worker Commit, Counter, DhPrivateKey, DhPublicKey, HandshakeMessage,
20*60b67249SAndroid Build Coastguard Worker HandshakePayload, Hash, SessionCrypto,
21*60b67249SAndroid Build Coastguard Worker };
22*60b67249SAndroid Build Coastguard Worker use crate::error::{DpeResult, ErrCode};
23*60b67249SAndroid Build Coastguard Worker use crate::memory::Message;
24*60b67249SAndroid Build Coastguard Worker use core::marker::PhantomData;
25*60b67249SAndroid Build Coastguard Worker use log::{debug, error};
26*60b67249SAndroid Build Coastguard Worker use noise_protocol::{HandshakeStateBuilder, Hash as NoiseHash, U8Array};
27*60b67249SAndroid Build Coastguard Worker
28*60b67249SAndroid Build Coastguard Worker impl From<noise_protocol::Error> for ErrCode {
from(_err: noise_protocol::Error) -> Self29*60b67249SAndroid Build Coastguard Worker fn from(_err: noise_protocol::Error) -> Self {
30*60b67249SAndroid Build Coastguard Worker ErrCode::InvalidArgument
31*60b67249SAndroid Build Coastguard Worker }
32*60b67249SAndroid Build Coastguard Worker }
33*60b67249SAndroid Build Coastguard Worker
34*60b67249SAndroid Build Coastguard Worker impl<NoiseHash> From<&NoiseHash> for Hash
35*60b67249SAndroid Build Coastguard Worker where
36*60b67249SAndroid Build Coastguard Worker NoiseHash: U8Array,
37*60b67249SAndroid Build Coastguard Worker {
from(value: &NoiseHash) -> Self38*60b67249SAndroid Build Coastguard Worker fn from(value: &NoiseHash) -> Self {
39*60b67249SAndroid Build Coastguard Worker // The Noise hash size may not match HASH_SIZE.
40*60b67249SAndroid Build Coastguard Worker Hash::from_slice_infallible(value.as_slice())
41*60b67249SAndroid Build Coastguard Worker }
42*60b67249SAndroid Build Coastguard Worker }
43*60b67249SAndroid Build Coastguard Worker
44*60b67249SAndroid Build Coastguard Worker /// A cipher state type that can be used as a
45*60b67249SAndroid Build Coastguard Worker /// [`SessionCipherState`](crate::crypto::SessionCrypto::SessionCipherState).
46*60b67249SAndroid Build Coastguard Worker pub struct NoiseCipherState<C: noise_protocol::Cipher> {
47*60b67249SAndroid Build Coastguard Worker k: C::Key,
48*60b67249SAndroid Build Coastguard Worker n: u64,
49*60b67249SAndroid Build Coastguard Worker n_staged: u64,
50*60b67249SAndroid Build Coastguard Worker }
51*60b67249SAndroid Build Coastguard Worker
52*60b67249SAndroid Build Coastguard Worker impl<C: noise_protocol::Cipher> Clone for NoiseCipherState<C> {
clone(&self) -> Self53*60b67249SAndroid Build Coastguard Worker fn clone(&self) -> Self {
54*60b67249SAndroid Build Coastguard Worker Self { k: self.k.clone(), n: self.n, n_staged: self.n_staged }
55*60b67249SAndroid Build Coastguard Worker }
56*60b67249SAndroid Build Coastguard Worker }
57*60b67249SAndroid Build Coastguard Worker
58*60b67249SAndroid Build Coastguard Worker impl<C: noise_protocol::Cipher> Default for NoiseCipherState<C> {
default() -> Self59*60b67249SAndroid Build Coastguard Worker fn default() -> Self {
60*60b67249SAndroid Build Coastguard Worker Self { k: C::Key::new(), n: 0, n_staged: 0 }
61*60b67249SAndroid Build Coastguard Worker }
62*60b67249SAndroid Build Coastguard Worker }
63*60b67249SAndroid Build Coastguard Worker
64*60b67249SAndroid Build Coastguard Worker impl<C: noise_protocol::Cipher> core::fmt::Debug for NoiseCipherState<C> {
fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result65*60b67249SAndroid Build Coastguard Worker fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
66*60b67249SAndroid Build Coastguard Worker write!(f, "k: redacted, n: {}", self.n)?;
67*60b67249SAndroid Build Coastguard Worker Ok(())
68*60b67249SAndroid Build Coastguard Worker }
69*60b67249SAndroid Build Coastguard Worker }
70*60b67249SAndroid Build Coastguard Worker
71*60b67249SAndroid Build Coastguard Worker impl<C: noise_protocol::Cipher> core::hash::Hash for NoiseCipherState<C> {
hash<H: core::hash::Hasher>(&self, state: &mut H)72*60b67249SAndroid Build Coastguard Worker fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
73*60b67249SAndroid Build Coastguard Worker self.k.as_slice().hash(state);
74*60b67249SAndroid Build Coastguard Worker self.n.hash(state);
75*60b67249SAndroid Build Coastguard Worker self.n_staged.hash(state);
76*60b67249SAndroid Build Coastguard Worker }
77*60b67249SAndroid Build Coastguard Worker }
78*60b67249SAndroid Build Coastguard Worker
#[cfg(test)]
impl<C> PartialEq for NoiseCipherState<C>
where
    C: noise_protocol::Cipher,
{
    /// Two states are equal when key bytes and both counters match.
    fn eq(&self, other: &Self) -> bool {
        let counters_match =
            self.n == other.n && self.n_staged == other.n_staged;
        counters_match && self.k.as_slice() == other.k.as_slice()
    }
}

#[cfg(test)]
impl<C> Eq for NoiseCipherState<C> where C: noise_protocol::Cipher {}
90*60b67249SAndroid Build Coastguard Worker
91*60b67249SAndroid Build Coastguard Worker impl<C: noise_protocol::Cipher> Counter for NoiseCipherState<C> {
n(&self) -> u6492*60b67249SAndroid Build Coastguard Worker fn n(&self) -> u64 {
93*60b67249SAndroid Build Coastguard Worker self.n
94*60b67249SAndroid Build Coastguard Worker }
set_n(&mut self, n: u64)95*60b67249SAndroid Build Coastguard Worker fn set_n(&mut self, n: u64) {
96*60b67249SAndroid Build Coastguard Worker self.n = n;
97*60b67249SAndroid Build Coastguard Worker }
98*60b67249SAndroid Build Coastguard Worker }
99*60b67249SAndroid Build Coastguard Worker
100*60b67249SAndroid Build Coastguard Worker impl<C: noise_protocol::Cipher> Commit for NoiseCipherState<C> {
101*60b67249SAndroid Build Coastguard Worker // Called when an encrypted message is finalized to commit the new cipher
102*60b67249SAndroid Build Coastguard Worker // state.
commit(&mut self)103*60b67249SAndroid Build Coastguard Worker fn commit(&mut self) {
104*60b67249SAndroid Build Coastguard Worker self.n = self.n_staged;
105*60b67249SAndroid Build Coastguard Worker }
106*60b67249SAndroid Build Coastguard Worker }
107*60b67249SAndroid Build Coastguard Worker
108*60b67249SAndroid Build Coastguard Worker impl<C: noise_protocol::Cipher> From<&noise_protocol::CipherState<C>>
109*60b67249SAndroid Build Coastguard Worker for NoiseCipherState<C>
110*60b67249SAndroid Build Coastguard Worker {
from(cs: &noise_protocol::CipherState<C>) -> Self111*60b67249SAndroid Build Coastguard Worker fn from(cs: &noise_protocol::CipherState<C>) -> Self {
112*60b67249SAndroid Build Coastguard Worker let (key, counter) = cs.clone().extract();
113*60b67249SAndroid Build Coastguard Worker NoiseCipherState { k: key, n: counter, n_staged: counter }
114*60b67249SAndroid Build Coastguard Worker }
115*60b67249SAndroid Build Coastguard Worker }
116*60b67249SAndroid Build Coastguard Worker
117*60b67249SAndroid Build Coastguard Worker /// Returns the public key corresponding to a given `dh_private_key`.
get_dh_public_key<D: noise_protocol::DH>( dh_private_key: &DhPrivateKey, ) -> DpeResult<DhPublicKey>118*60b67249SAndroid Build Coastguard Worker pub fn get_dh_public_key<D: noise_protocol::DH>(
119*60b67249SAndroid Build Coastguard Worker dh_private_key: &DhPrivateKey,
120*60b67249SAndroid Build Coastguard Worker ) -> DpeResult<DhPublicKey> {
121*60b67249SAndroid Build Coastguard Worker DhPublicKey::from_slice(
122*60b67249SAndroid Build Coastguard Worker D::pubkey(&D::Key::from_slice(dh_private_key.as_slice())).as_slice(),
123*60b67249SAndroid Build Coastguard Worker )
124*60b67249SAndroid Build Coastguard Worker }
125*60b67249SAndroid Build Coastguard Worker
126*60b67249SAndroid Build Coastguard Worker /// A trait representing [`NoiseSessionCrypto`] dependencies.
127*60b67249SAndroid Build Coastguard Worker pub trait NoiseCryptoDeps {
128*60b67249SAndroid Build Coastguard Worker /// Cipher type
129*60b67249SAndroid Build Coastguard Worker type Cipher: noise_protocol::Cipher;
130*60b67249SAndroid Build Coastguard Worker /// DH type
131*60b67249SAndroid Build Coastguard Worker type DH: noise_protocol::DH;
132*60b67249SAndroid Build Coastguard Worker /// Hash type
133*60b67249SAndroid Build Coastguard Worker type Hash: noise_protocol::Hash;
134*60b67249SAndroid Build Coastguard Worker }
135*60b67249SAndroid Build Coastguard Worker
136*60b67249SAndroid Build Coastguard Worker /// A Noise implementation of the [`SessionCrypto`] trait.
137*60b67249SAndroid Build Coastguard Worker pub struct NoiseSessionCrypto<D: NoiseCryptoDeps> {
138*60b67249SAndroid Build Coastguard Worker #[allow(dead_code)]
139*60b67249SAndroid Build Coastguard Worker phantom: PhantomData<D>,
140*60b67249SAndroid Build Coastguard Worker }
141*60b67249SAndroid Build Coastguard Worker
142*60b67249SAndroid Build Coastguard Worker impl<D> Clone for NoiseSessionCrypto<D>
143*60b67249SAndroid Build Coastguard Worker where
144*60b67249SAndroid Build Coastguard Worker D: NoiseCryptoDeps,
145*60b67249SAndroid Build Coastguard Worker {
clone(&self) -> Self146*60b67249SAndroid Build Coastguard Worker fn clone(&self) -> Self {
147*60b67249SAndroid Build Coastguard Worker Self { phantom: Default::default() }
148*60b67249SAndroid Build Coastguard Worker }
149*60b67249SAndroid Build Coastguard Worker }
150*60b67249SAndroid Build Coastguard Worker
151*60b67249SAndroid Build Coastguard Worker impl<D> Default for NoiseSessionCrypto<D>
152*60b67249SAndroid Build Coastguard Worker where
153*60b67249SAndroid Build Coastguard Worker D: NoiseCryptoDeps,
154*60b67249SAndroid Build Coastguard Worker {
default() -> Self155*60b67249SAndroid Build Coastguard Worker fn default() -> Self {
156*60b67249SAndroid Build Coastguard Worker Self { phantom: Default::default() }
157*60b67249SAndroid Build Coastguard Worker }
158*60b67249SAndroid Build Coastguard Worker }
159*60b67249SAndroid Build Coastguard Worker
160*60b67249SAndroid Build Coastguard Worker impl<D> core::fmt::Debug for NoiseSessionCrypto<D>
161*60b67249SAndroid Build Coastguard Worker where
162*60b67249SAndroid Build Coastguard Worker D: NoiseCryptoDeps,
163*60b67249SAndroid Build Coastguard Worker {
fmt(&self, _: &mut core::fmt::Formatter<'_>) -> core::fmt::Result164*60b67249SAndroid Build Coastguard Worker fn fmt(&self, _: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
165*60b67249SAndroid Build Coastguard Worker Ok(())
166*60b67249SAndroid Build Coastguard Worker }
167*60b67249SAndroid Build Coastguard Worker }
168*60b67249SAndroid Build Coastguard Worker
169*60b67249SAndroid Build Coastguard Worker impl<D> core::hash::Hash for NoiseSessionCrypto<D>
170*60b67249SAndroid Build Coastguard Worker where
171*60b67249SAndroid Build Coastguard Worker D: NoiseCryptoDeps,
172*60b67249SAndroid Build Coastguard Worker {
hash<Hr: core::hash::Hasher>(&self, _: &mut Hr)173*60b67249SAndroid Build Coastguard Worker fn hash<Hr: core::hash::Hasher>(&self, _: &mut Hr) {}
174*60b67249SAndroid Build Coastguard Worker }
175*60b67249SAndroid Build Coastguard Worker
176*60b67249SAndroid Build Coastguard Worker impl<D> PartialEq for NoiseSessionCrypto<D>
177*60b67249SAndroid Build Coastguard Worker where
178*60b67249SAndroid Build Coastguard Worker D: NoiseCryptoDeps,
179*60b67249SAndroid Build Coastguard Worker {
eq(&self, _: &Self) -> bool180*60b67249SAndroid Build Coastguard Worker fn eq(&self, _: &Self) -> bool {
181*60b67249SAndroid Build Coastguard Worker true
182*60b67249SAndroid Build Coastguard Worker }
183*60b67249SAndroid Build Coastguard Worker }
184*60b67249SAndroid Build Coastguard Worker
185*60b67249SAndroid Build Coastguard Worker impl<D> Eq for NoiseSessionCrypto<D> where D: NoiseCryptoDeps {}
186*60b67249SAndroid Build Coastguard Worker
187*60b67249SAndroid Build Coastguard Worker impl<D> SessionCrypto for NoiseSessionCrypto<D>
188*60b67249SAndroid Build Coastguard Worker where
189*60b67249SAndroid Build Coastguard Worker D: NoiseCryptoDeps,
190*60b67249SAndroid Build Coastguard Worker {
191*60b67249SAndroid Build Coastguard Worker type SessionCipherState = NoiseCipherState<D::Cipher>;
192*60b67249SAndroid Build Coastguard Worker
193*60b67249SAndroid Build Coastguard Worker /// Implements the responder role of a Noise_NK handshake.
new_session_handshake( static_dh_key: &DhPrivateKey, initiator_handshake: &HandshakeMessage, payload: &HandshakePayload, responder_handshake: &mut HandshakeMessage, decrypt_cipher_state: &mut NoiseCipherState<D::Cipher>, encrypt_cipher_state: &mut NoiseCipherState<D::Cipher>, psk_seed: &mut Hash, ) -> DpeResult<()>194*60b67249SAndroid Build Coastguard Worker fn new_session_handshake(
195*60b67249SAndroid Build Coastguard Worker static_dh_key: &DhPrivateKey,
196*60b67249SAndroid Build Coastguard Worker initiator_handshake: &HandshakeMessage,
197*60b67249SAndroid Build Coastguard Worker payload: &HandshakePayload,
198*60b67249SAndroid Build Coastguard Worker responder_handshake: &mut HandshakeMessage,
199*60b67249SAndroid Build Coastguard Worker decrypt_cipher_state: &mut NoiseCipherState<D::Cipher>,
200*60b67249SAndroid Build Coastguard Worker encrypt_cipher_state: &mut NoiseCipherState<D::Cipher>,
201*60b67249SAndroid Build Coastguard Worker psk_seed: &mut Hash,
202*60b67249SAndroid Build Coastguard Worker ) -> DpeResult<()> {
203*60b67249SAndroid Build Coastguard Worker #[allow(unused_results)]
204*60b67249SAndroid Build Coastguard Worker let mut handshake: noise_protocol::HandshakeState<
205*60b67249SAndroid Build Coastguard Worker D::DH,
206*60b67249SAndroid Build Coastguard Worker D::Cipher,
207*60b67249SAndroid Build Coastguard Worker D::Hash,
208*60b67249SAndroid Build Coastguard Worker > = {
209*60b67249SAndroid Build Coastguard Worker let mut builder = HandshakeStateBuilder::new();
210*60b67249SAndroid Build Coastguard Worker builder.set_pattern(noise_protocol::patterns::noise_nk());
211*60b67249SAndroid Build Coastguard Worker builder.set_is_initiator(false);
212*60b67249SAndroid Build Coastguard Worker builder.set_prologue(&[]);
213*60b67249SAndroid Build Coastguard Worker builder.set_s(<D::DH as noise_protocol::DH>::Key::from_slice(
214*60b67249SAndroid Build Coastguard Worker static_dh_key.as_slice(),
215*60b67249SAndroid Build Coastguard Worker ));
216*60b67249SAndroid Build Coastguard Worker builder.build_handshake_state()
217*60b67249SAndroid Build Coastguard Worker };
218*60b67249SAndroid Build Coastguard Worker handshake.read_message(initiator_handshake.as_slice(), &mut [])?;
219*60b67249SAndroid Build Coastguard Worker handshake.write_message(
220*60b67249SAndroid Build Coastguard Worker payload.as_slice(),
221*60b67249SAndroid Build Coastguard Worker responder_handshake.as_mut_sized(
222*60b67249SAndroid Build Coastguard Worker handshake.get_next_message_overhead() + payload.len(),
223*60b67249SAndroid Build Coastguard Worker )?,
224*60b67249SAndroid Build Coastguard Worker )?;
225*60b67249SAndroid Build Coastguard Worker assert!(handshake.completed());
226*60b67249SAndroid Build Coastguard Worker let ciphers = handshake.get_ciphers();
227*60b67249SAndroid Build Coastguard Worker *decrypt_cipher_state = (&ciphers.0).into();
228*60b67249SAndroid Build Coastguard Worker *encrypt_cipher_state = (&ciphers.1).into();
229*60b67249SAndroid Build Coastguard Worker debug!("get_hash");
230*60b67249SAndroid Build Coastguard Worker *psk_seed = Hash::from_slice(handshake.get_hash())?;
231*60b67249SAndroid Build Coastguard Worker Ok(())
232*60b67249SAndroid Build Coastguard Worker }
233*60b67249SAndroid Build Coastguard Worker
234*60b67249SAndroid Build Coastguard Worker /// Implements the responder role of a Noise_NNpsk0 handshake.
derive_session_handshake( psk: &Hash, initiator_handshake: &HandshakeMessage, payload: &HandshakePayload, responder_handshake: &mut HandshakeMessage, decrypt_cipher_state: &mut NoiseCipherState<D::Cipher>, encrypt_cipher_state: &mut NoiseCipherState<D::Cipher>, psk_seed: &mut Hash, ) -> DpeResult<()>235*60b67249SAndroid Build Coastguard Worker fn derive_session_handshake(
236*60b67249SAndroid Build Coastguard Worker psk: &Hash,
237*60b67249SAndroid Build Coastguard Worker initiator_handshake: &HandshakeMessage,
238*60b67249SAndroid Build Coastguard Worker payload: &HandshakePayload,
239*60b67249SAndroid Build Coastguard Worker responder_handshake: &mut HandshakeMessage,
240*60b67249SAndroid Build Coastguard Worker decrypt_cipher_state: &mut NoiseCipherState<D::Cipher>,
241*60b67249SAndroid Build Coastguard Worker encrypt_cipher_state: &mut NoiseCipherState<D::Cipher>,
242*60b67249SAndroid Build Coastguard Worker psk_seed: &mut Hash,
243*60b67249SAndroid Build Coastguard Worker ) -> DpeResult<()> {
244*60b67249SAndroid Build Coastguard Worker #[allow(unused_results)]
245*60b67249SAndroid Build Coastguard Worker let mut handshake: noise_protocol::HandshakeState<
246*60b67249SAndroid Build Coastguard Worker D::DH,
247*60b67249SAndroid Build Coastguard Worker D::Cipher,
248*60b67249SAndroid Build Coastguard Worker D::Hash,
249*60b67249SAndroid Build Coastguard Worker > = {
250*60b67249SAndroid Build Coastguard Worker let mut builder = HandshakeStateBuilder::new();
251*60b67249SAndroid Build Coastguard Worker builder.set_pattern(noise_protocol::patterns::noise_nn_psk0());
252*60b67249SAndroid Build Coastguard Worker builder.set_is_initiator(false);
253*60b67249SAndroid Build Coastguard Worker builder.set_prologue(&[]);
254*60b67249SAndroid Build Coastguard Worker builder.build_handshake_state()
255*60b67249SAndroid Build Coastguard Worker };
256*60b67249SAndroid Build Coastguard Worker handshake
257*60b67249SAndroid Build Coastguard Worker .push_psk(psk.as_slice().get(..32).ok_or(ErrCode::InternalError)?);
258*60b67249SAndroid Build Coastguard Worker handshake.read_message(initiator_handshake.as_slice(), &mut [])?;
259*60b67249SAndroid Build Coastguard Worker handshake.write_message(
260*60b67249SAndroid Build Coastguard Worker payload.as_slice(),
261*60b67249SAndroid Build Coastguard Worker responder_handshake.as_mut_sized(
262*60b67249SAndroid Build Coastguard Worker handshake.get_next_message_overhead() + payload.len(),
263*60b67249SAndroid Build Coastguard Worker )?,
264*60b67249SAndroid Build Coastguard Worker )?;
265*60b67249SAndroid Build Coastguard Worker let ciphers = handshake.get_ciphers();
266*60b67249SAndroid Build Coastguard Worker *decrypt_cipher_state = (&ciphers.0).into();
267*60b67249SAndroid Build Coastguard Worker *encrypt_cipher_state = (&ciphers.1).into();
268*60b67249SAndroid Build Coastguard Worker *psk_seed = Hash::from_slice(handshake.get_hash())?;
269*60b67249SAndroid Build Coastguard Worker Ok(())
270*60b67249SAndroid Build Coastguard Worker }
271*60b67249SAndroid Build Coastguard Worker
272*60b67249SAndroid Build Coastguard Worker /// Encrypts a Noise transport message in place.
session_encrypt( cipher_state: &mut NoiseCipherState<D::Cipher>, in_place_buffer: &mut Message, ) -> DpeResult<()>273*60b67249SAndroid Build Coastguard Worker fn session_encrypt(
274*60b67249SAndroid Build Coastguard Worker cipher_state: &mut NoiseCipherState<D::Cipher>,
275*60b67249SAndroid Build Coastguard Worker in_place_buffer: &mut Message,
276*60b67249SAndroid Build Coastguard Worker ) -> DpeResult<()> {
277*60b67249SAndroid Build Coastguard Worker let mut cs = noise_protocol::CipherState::<D::Cipher>::new(
278*60b67249SAndroid Build Coastguard Worker cipher_state.k.as_slice(),
279*60b67249SAndroid Build Coastguard Worker cipher_state.n,
280*60b67249SAndroid Build Coastguard Worker );
281*60b67249SAndroid Build Coastguard Worker let plaintext_len = in_place_buffer.len();
282*60b67249SAndroid Build Coastguard Worker let _ = cs.encrypt_in_place(
283*60b67249SAndroid Build Coastguard Worker in_place_buffer.as_mut_sized(
284*60b67249SAndroid Build Coastguard Worker plaintext_len
285*60b67249SAndroid Build Coastguard Worker + <D::Cipher as noise_protocol::Cipher>::tag_len(),
286*60b67249SAndroid Build Coastguard Worker )?,
287*60b67249SAndroid Build Coastguard Worker plaintext_len,
288*60b67249SAndroid Build Coastguard Worker );
289*60b67249SAndroid Build Coastguard Worker // Encrypting a message is usually not the final step in preparing
290*60b67249SAndroid Build Coastguard Worker // the message for transport. If a subsequent step fails, it is
291*60b67249SAndroid Build Coastguard Worker // better for 'n' to remain unchanged so we don't get out of sync.
292*60b67249SAndroid Build Coastguard Worker (_, cipher_state.n_staged) = cs.extract();
293*60b67249SAndroid Build Coastguard Worker Ok(())
294*60b67249SAndroid Build Coastguard Worker }
295*60b67249SAndroid Build Coastguard Worker
296*60b67249SAndroid Build Coastguard Worker /// Decrypts a Noise transport message in place.
session_decrypt( cipher_state: &mut NoiseCipherState<D::Cipher>, in_place_buffer: &mut Message, ) -> DpeResult<()>297*60b67249SAndroid Build Coastguard Worker fn session_decrypt(
298*60b67249SAndroid Build Coastguard Worker cipher_state: &mut NoiseCipherState<D::Cipher>,
299*60b67249SAndroid Build Coastguard Worker in_place_buffer: &mut Message,
300*60b67249SAndroid Build Coastguard Worker ) -> DpeResult<()> {
301*60b67249SAndroid Build Coastguard Worker let mut cs = noise_protocol::CipherState::<D::Cipher>::new(
302*60b67249SAndroid Build Coastguard Worker cipher_state.k.as_slice(),
303*60b67249SAndroid Build Coastguard Worker cipher_state.n,
304*60b67249SAndroid Build Coastguard Worker );
305*60b67249SAndroid Build Coastguard Worker let ciphertext_len = in_place_buffer.len();
306*60b67249SAndroid Build Coastguard Worker let plaintext_len = match cs
307*60b67249SAndroid Build Coastguard Worker .decrypt_in_place(in_place_buffer.vec.as_mut(), ciphertext_len)
308*60b67249SAndroid Build Coastguard Worker {
309*60b67249SAndroid Build Coastguard Worker Ok(length) => length,
310*60b67249SAndroid Build Coastguard Worker _ => {
311*60b67249SAndroid Build Coastguard Worker error!("Session decrypt failed");
312*60b67249SAndroid Build Coastguard Worker return Err(ErrCode::InvalidCommand);
313*60b67249SAndroid Build Coastguard Worker }
314*60b67249SAndroid Build Coastguard Worker };
315*60b67249SAndroid Build Coastguard Worker in_place_buffer.vec.truncate(plaintext_len);
316*60b67249SAndroid Build Coastguard Worker (_, cipher_state.n) = cs.extract();
317*60b67249SAndroid Build Coastguard Worker Ok(())
318*60b67249SAndroid Build Coastguard Worker }
319*60b67249SAndroid Build Coastguard Worker
320*60b67249SAndroid Build Coastguard Worker /// Derives a responder-side PSK.
derive_psk_from_session( psk_seed: &Hash, decrypt_cipher_state: &NoiseCipherState<D::Cipher>, encrypt_cipher_state: &NoiseCipherState<D::Cipher>, ) -> DpeResult<Hash>321*60b67249SAndroid Build Coastguard Worker fn derive_psk_from_session(
322*60b67249SAndroid Build Coastguard Worker psk_seed: &Hash,
323*60b67249SAndroid Build Coastguard Worker decrypt_cipher_state: &NoiseCipherState<D::Cipher>,
324*60b67249SAndroid Build Coastguard Worker encrypt_cipher_state: &NoiseCipherState<D::Cipher>,
325*60b67249SAndroid Build Coastguard Worker ) -> DpeResult<Hash> {
326*60b67249SAndroid Build Coastguard Worker let mut hasher: D::Hash = Default::default();
327*60b67249SAndroid Build Coastguard Worker hasher.input(psk_seed.as_slice());
328*60b67249SAndroid Build Coastguard Worker // Use the decrypt state as it was before we decrypted the current
329*60b67249SAndroid Build Coastguard Worker // command message. This allows clients to compute the PSK using
330*60b67249SAndroid Build Coastguard Worker // the cipher states as they are before the client sends the
331*60b67249SAndroid Build Coastguard Worker // command.
332*60b67249SAndroid Build Coastguard Worker hasher.input(&(decrypt_cipher_state.n() - 1).to_le_bytes());
333*60b67249SAndroid Build Coastguard Worker hasher.input(&encrypt_cipher_state.n().to_le_bytes());
334*60b67249SAndroid Build Coastguard Worker Ok((&hasher.result()).into())
335*60b67249SAndroid Build Coastguard Worker }
336*60b67249SAndroid Build Coastguard Worker }
337*60b67249SAndroid Build Coastguard Worker
338*60b67249SAndroid Build Coastguard Worker /// A SessionClient implements the initiator side of an encrypted session. A
339*60b67249SAndroid Build Coastguard Worker /// DPE does not use this itself, it is useful for clients and testing.
340*60b67249SAndroid Build Coastguard Worker pub struct SessionClient<D>
341*60b67249SAndroid Build Coastguard Worker where
342*60b67249SAndroid Build Coastguard Worker D: NoiseCryptoDeps,
343*60b67249SAndroid Build Coastguard Worker {
344*60b67249SAndroid Build Coastguard Worker handshake_state:
345*60b67249SAndroid Build Coastguard Worker Option<noise_protocol::HandshakeState<D::DH, D::Cipher, D::Hash>>,
346*60b67249SAndroid Build Coastguard Worker /// Cipher state for encrypting messages to a DPE.
347*60b67249SAndroid Build Coastguard Worker pub encrypt_cipher_state: NoiseCipherState<D::Cipher>,
348*60b67249SAndroid Build Coastguard Worker /// Cipher state for decrypting messages from a DPE.
349*60b67249SAndroid Build Coastguard Worker pub decrypt_cipher_state: NoiseCipherState<D::Cipher>,
350*60b67249SAndroid Build Coastguard Worker /// PSK seed for deriving sessions. See [`derive_psk`].
351*60b67249SAndroid Build Coastguard Worker ///
352*60b67249SAndroid Build Coastguard Worker /// [`derive_psk`]: #method.derive_psk
353*60b67249SAndroid Build Coastguard Worker pub psk_seed: Hash,
354*60b67249SAndroid Build Coastguard Worker }
355*60b67249SAndroid Build Coastguard Worker
356*60b67249SAndroid Build Coastguard Worker impl<D> Clone for SessionClient<D>
357*60b67249SAndroid Build Coastguard Worker where
358*60b67249SAndroid Build Coastguard Worker D: NoiseCryptoDeps,
359*60b67249SAndroid Build Coastguard Worker {
clone(&self) -> Self360*60b67249SAndroid Build Coastguard Worker fn clone(&self) -> Self {
361*60b67249SAndroid Build Coastguard Worker Self {
362*60b67249SAndroid Build Coastguard Worker handshake_state: self.handshake_state.clone(),
363*60b67249SAndroid Build Coastguard Worker encrypt_cipher_state: self.encrypt_cipher_state.clone(),
364*60b67249SAndroid Build Coastguard Worker decrypt_cipher_state: self.decrypt_cipher_state.clone(),
365*60b67249SAndroid Build Coastguard Worker psk_seed: self.psk_seed.clone(),
366*60b67249SAndroid Build Coastguard Worker }
367*60b67249SAndroid Build Coastguard Worker }
368*60b67249SAndroid Build Coastguard Worker }
369*60b67249SAndroid Build Coastguard Worker
370*60b67249SAndroid Build Coastguard Worker impl<D> Default for SessionClient<D>
371*60b67249SAndroid Build Coastguard Worker where
372*60b67249SAndroid Build Coastguard Worker D: NoiseCryptoDeps,
373*60b67249SAndroid Build Coastguard Worker {
default() -> Self374*60b67249SAndroid Build Coastguard Worker fn default() -> Self {
375*60b67249SAndroid Build Coastguard Worker Self::new()
376*60b67249SAndroid Build Coastguard Worker }
377*60b67249SAndroid Build Coastguard Worker }
378*60b67249SAndroid Build Coastguard Worker
379*60b67249SAndroid Build Coastguard Worker impl<D> core::fmt::Debug for SessionClient<D>
380*60b67249SAndroid Build Coastguard Worker where
381*60b67249SAndroid Build Coastguard Worker D: NoiseCryptoDeps,
382*60b67249SAndroid Build Coastguard Worker {
fmt(&self, _: &mut core::fmt::Formatter<'_>) -> core::fmt::Result383*60b67249SAndroid Build Coastguard Worker fn fmt(&self, _: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
384*60b67249SAndroid Build Coastguard Worker Ok(())
385*60b67249SAndroid Build Coastguard Worker }
386*60b67249SAndroid Build Coastguard Worker }
387*60b67249SAndroid Build Coastguard Worker
388*60b67249SAndroid Build Coastguard Worker impl<D> SessionClient<D>
389*60b67249SAndroid Build Coastguard Worker where
390*60b67249SAndroid Build Coastguard Worker D: NoiseCryptoDeps,
391*60b67249SAndroid Build Coastguard Worker {
392*60b67249SAndroid Build Coastguard Worker /// Creates a new SessionClient instance. Set up by starting and finishing a
393*60b67249SAndroid Build Coastguard Worker /// handshake.
new() -> Self394*60b67249SAndroid Build Coastguard Worker pub fn new() -> Self {
395*60b67249SAndroid Build Coastguard Worker Self {
396*60b67249SAndroid Build Coastguard Worker handshake_state: Default::default(),
397*60b67249SAndroid Build Coastguard Worker encrypt_cipher_state: Default::default(),
398*60b67249SAndroid Build Coastguard Worker decrypt_cipher_state: Default::default(),
399*60b67249SAndroid Build Coastguard Worker psk_seed: Default::default(),
400*60b67249SAndroid Build Coastguard Worker }
401*60b67249SAndroid Build Coastguard Worker }
402*60b67249SAndroid Build Coastguard Worker
403*60b67249SAndroid Build Coastguard Worker /// Starts a handshake using a known `public_key` and returns a message that
404*60b67249SAndroid Build Coastguard Worker /// works with the DPE OpenSession command.
start_handshake_with_known_public_key( &mut self, public_key: &DhPublicKey, ) -> DpeResult<HandshakeMessage>405*60b67249SAndroid Build Coastguard Worker pub fn start_handshake_with_known_public_key(
406*60b67249SAndroid Build Coastguard Worker &mut self,
407*60b67249SAndroid Build Coastguard Worker public_key: &DhPublicKey,
408*60b67249SAndroid Build Coastguard Worker ) -> DpeResult<HandshakeMessage> {
409*60b67249SAndroid Build Coastguard Worker #[allow(unused_results)]
410*60b67249SAndroid Build Coastguard Worker let mut handshake_state = {
411*60b67249SAndroid Build Coastguard Worker let mut builder = HandshakeStateBuilder::new();
412*60b67249SAndroid Build Coastguard Worker builder.set_pattern(noise_protocol::patterns::noise_nk());
413*60b67249SAndroid Build Coastguard Worker builder.set_is_initiator(true);
414*60b67249SAndroid Build Coastguard Worker builder.set_prologue(&[]);
415*60b67249SAndroid Build Coastguard Worker builder.set_rs(<D::DH as noise_protocol::DH>::Pubkey::from_slice(
416*60b67249SAndroid Build Coastguard Worker public_key.as_slice(),
417*60b67249SAndroid Build Coastguard Worker ));
418*60b67249SAndroid Build Coastguard Worker builder.build_handshake_state()
419*60b67249SAndroid Build Coastguard Worker };
420*60b67249SAndroid Build Coastguard Worker let mut message = HandshakeMessage::new();
421*60b67249SAndroid Build Coastguard Worker handshake_state.write_message(
422*60b67249SAndroid Build Coastguard Worker &[],
423*60b67249SAndroid Build Coastguard Worker message
424*60b67249SAndroid Build Coastguard Worker .as_mut_sized(handshake_state.get_next_message_overhead())?,
425*60b67249SAndroid Build Coastguard Worker )?;
426*60b67249SAndroid Build Coastguard Worker self.handshake_state = Some(handshake_state);
427*60b67249SAndroid Build Coastguard Worker Ok(message)
428*60b67249SAndroid Build Coastguard Worker }
429*60b67249SAndroid Build Coastguard Worker
430*60b67249SAndroid Build Coastguard Worker /// Starts a handshake using a `psk` and returns a message that works with
431*60b67249SAndroid Build Coastguard Worker /// the DPE DeriveContext command. Use [`derive_psk`] to obtain this value
432*60b67249SAndroid Build Coastguard Worker /// from an existing session.
433*60b67249SAndroid Build Coastguard Worker ///
434*60b67249SAndroid Build Coastguard Worker /// [`derive_psk`]: #method.derive_psk
start_handshake_with_psk( &mut self, psk: &Hash, ) -> DpeResult<HandshakeMessage>435*60b67249SAndroid Build Coastguard Worker pub fn start_handshake_with_psk(
436*60b67249SAndroid Build Coastguard Worker &mut self,
437*60b67249SAndroid Build Coastguard Worker psk: &Hash,
438*60b67249SAndroid Build Coastguard Worker ) -> DpeResult<HandshakeMessage> {
439*60b67249SAndroid Build Coastguard Worker #[allow(unused_results)]
440*60b67249SAndroid Build Coastguard Worker let mut handshake_state = {
441*60b67249SAndroid Build Coastguard Worker let mut builder = HandshakeStateBuilder::new();
442*60b67249SAndroid Build Coastguard Worker builder.set_pattern(noise_protocol::patterns::noise_nn_psk0());
443*60b67249SAndroid Build Coastguard Worker builder.set_is_initiator(true);
444*60b67249SAndroid Build Coastguard Worker builder.set_prologue(&[]);
445*60b67249SAndroid Build Coastguard Worker builder.build_handshake_state()
446*60b67249SAndroid Build Coastguard Worker };
447*60b67249SAndroid Build Coastguard Worker handshake_state
448*60b67249SAndroid Build Coastguard Worker .push_psk(psk.as_slice().get(..32).ok_or(ErrCode::InternalError)?);
449*60b67249SAndroid Build Coastguard Worker let mut message = HandshakeMessage::new();
450*60b67249SAndroid Build Coastguard Worker handshake_state.write_message(
451*60b67249SAndroid Build Coastguard Worker &[],
452*60b67249SAndroid Build Coastguard Worker message
453*60b67249SAndroid Build Coastguard Worker .as_mut_sized(handshake_state.get_next_message_overhead())?,
454*60b67249SAndroid Build Coastguard Worker )?;
455*60b67249SAndroid Build Coastguard Worker self.handshake_state = Some(handshake_state);
456*60b67249SAndroid Build Coastguard Worker Ok(message)
457*60b67249SAndroid Build Coastguard Worker }
458*60b67249SAndroid Build Coastguard Worker
459*60b67249SAndroid Build Coastguard Worker /// Finishes a handshake started using one of the start_handshake_* methods.
460*60b67249SAndroid Build Coastguard Worker /// On success, returns the handshake payload from the responder and sets up
461*60b67249SAndroid Build Coastguard Worker /// internal state for subsequent calls to encrypt and decrypt.
finish_handshake( &mut self, responder_handshake: &HandshakeMessage, ) -> DpeResult<HandshakePayload>462*60b67249SAndroid Build Coastguard Worker pub fn finish_handshake(
463*60b67249SAndroid Build Coastguard Worker &mut self,
464*60b67249SAndroid Build Coastguard Worker responder_handshake: &HandshakeMessage,
465*60b67249SAndroid Build Coastguard Worker ) -> DpeResult<HandshakePayload> {
466*60b67249SAndroid Build Coastguard Worker match self.handshake_state {
467*60b67249SAndroid Build Coastguard Worker None => Err(ErrCode::InvalidArgument),
468*60b67249SAndroid Build Coastguard Worker Some(ref mut handshake) => {
469*60b67249SAndroid Build Coastguard Worker let mut payload = HandshakePayload::new();
470*60b67249SAndroid Build Coastguard Worker handshake.read_message(
471*60b67249SAndroid Build Coastguard Worker responder_handshake.as_slice(),
472*60b67249SAndroid Build Coastguard Worker payload.as_mut_sized(
473*60b67249SAndroid Build Coastguard Worker responder_handshake.len()
474*60b67249SAndroid Build Coastguard Worker - handshake.get_next_message_overhead(),
475*60b67249SAndroid Build Coastguard Worker )?,
476*60b67249SAndroid Build Coastguard Worker )?;
477*60b67249SAndroid Build Coastguard Worker let ciphers = handshake.get_ciphers();
478*60b67249SAndroid Build Coastguard Worker self.encrypt_cipher_state = (&ciphers.0).into();
479*60b67249SAndroid Build Coastguard Worker self.decrypt_cipher_state = (&ciphers.1).into();
480*60b67249SAndroid Build Coastguard Worker self.psk_seed = Hash::from_slice(handshake.get_hash())?;
481*60b67249SAndroid Build Coastguard Worker Ok(payload)
482*60b67249SAndroid Build Coastguard Worker }
483*60b67249SAndroid Build Coastguard Worker }
484*60b67249SAndroid Build Coastguard Worker }
485*60b67249SAndroid Build Coastguard Worker
486*60b67249SAndroid Build Coastguard Worker /// Derives a PSK from the current session.
derive_psk(&self) -> Hash487*60b67249SAndroid Build Coastguard Worker pub fn derive_psk(&self) -> Hash {
488*60b67249SAndroid Build Coastguard Worker // Note this is from a client perspective so the counters are hashed
489*60b67249SAndroid Build Coastguard Worker // encrypt first and unmodified from their current state. A DPE will
490*60b67249SAndroid Build Coastguard Worker // reverse the order and decrement the first counter in order to derive
491*60b67249SAndroid Build Coastguard Worker // the same value (see derive_psk_from_session).
492*60b67249SAndroid Build Coastguard Worker let mut hasher: D::Hash = Default::default();
493*60b67249SAndroid Build Coastguard Worker hasher.input(self.psk_seed.as_slice());
494*60b67249SAndroid Build Coastguard Worker hasher.input(&self.encrypt_cipher_state.n().to_le_bytes());
495*60b67249SAndroid Build Coastguard Worker hasher.input(&self.decrypt_cipher_state.n().to_le_bytes());
496*60b67249SAndroid Build Coastguard Worker (&hasher.result()).into()
497*60b67249SAndroid Build Coastguard Worker }
498*60b67249SAndroid Build Coastguard Worker
499*60b67249SAndroid Build Coastguard Worker /// Encrypts a message to send to a DPE and commits cipher state changes.
encrypt(&mut self, in_place_buffer: &mut Message) -> DpeResult<()>500*60b67249SAndroid Build Coastguard Worker pub fn encrypt(&mut self, in_place_buffer: &mut Message) -> DpeResult<()> {
501*60b67249SAndroid Build Coastguard Worker NoiseSessionCrypto::<D>::session_encrypt(
502*60b67249SAndroid Build Coastguard Worker &mut self.encrypt_cipher_state,
503*60b67249SAndroid Build Coastguard Worker in_place_buffer,
504*60b67249SAndroid Build Coastguard Worker )?;
505*60b67249SAndroid Build Coastguard Worker self.encrypt_cipher_state.commit();
506*60b67249SAndroid Build Coastguard Worker Ok(())
507*60b67249SAndroid Build Coastguard Worker }
508*60b67249SAndroid Build Coastguard Worker
509*60b67249SAndroid Build Coastguard Worker /// Decrypts a message from a DPE.
decrypt(&mut self, in_place_buffer: &mut Message) -> DpeResult<()>510*60b67249SAndroid Build Coastguard Worker pub fn decrypt(&mut self, in_place_buffer: &mut Message) -> DpeResult<()> {
511*60b67249SAndroid Build Coastguard Worker NoiseSessionCrypto::<D>::session_decrypt(
512*60b67249SAndroid Build Coastguard Worker &mut self.decrypt_cipher_state,
513*60b67249SAndroid Build Coastguard Worker in_place_buffer,
514*60b67249SAndroid Build Coastguard Worker )
515*60b67249SAndroid Build Coastguard Worker }
516*60b67249SAndroid Build Coastguard Worker }
517*60b67249SAndroid Build Coastguard Worker
#[cfg(test)]
mod tests {
    use super::*;

    // Concrete primitives for the suite named in the module docs:
    // AES-256-GCM, X25519 and SHA-512 from `noise_rust_crypto`.
    struct DepsForTesting {}
    impl NoiseCryptoDeps for DepsForTesting {
        type Cipher = noise_rust_crypto::Aes256Gcm;
        type DH = noise_rust_crypto::X25519;
        type Hash = noise_rust_crypto::Sha512;
    }

    // Plays the DPE (responder) side of each session in these tests.
    type SessionCryptoForTesting = NoiseSessionCrypto<DepsForTesting>;

    // The client (initiator) under test.
    type SessionClientForTesting = SessionClient<DepsForTesting>;

    type CipherStateForTesting = NoiseCipherState<noise_rust_crypto::Aes256Gcm>;

    // Opens a session via `start_handshake_with_known_public_key` on the
    // client and `new_session_handshake` on the DPE side, checks that the
    // DPE's handshake payload reaches the client, then round-trips a message
    // in both directions twice.
    #[test]
    fn end_to_end_session() {
        let mut client = SessionClientForTesting::new();
        // The DPE's static key is simply `DhPrivateKey::default()` here.
        let dh_key: DhPrivateKey = Default::default();
        let dh_public_key = get_dh_public_key::<
            <DepsForTesting as NoiseCryptoDeps>::DH,
        >(&dh_key)
        .unwrap();
        // Client -> DPE: first handshake message.
        let handshake1 = client
            .start_handshake_with_known_public_key(&dh_public_key)
            .unwrap();
        let mut dpe_decrypt_cs: CipherStateForTesting = Default::default();
        let mut dpe_encrypt_cs: CipherStateForTesting = Default::default();
        let mut psk_seed = Default::default();
        let mut handshake2 = Default::default();
        let payload = HandshakePayload::from_slice("pay".as_bytes()).unwrap();
        // DPE: consume `handshake1` and write the reply carrying `payload`
        // into `handshake2`, filling the DPE-side out-params as it goes.
        SessionCryptoForTesting::new_session_handshake(
            &dh_key,
            &handshake1,
            &payload,
            &mut handshake2,
            &mut dpe_decrypt_cs,
            &mut dpe_encrypt_cs,
            &mut psk_seed,
        )
        .unwrap();
        // Client: complete the handshake and recover the DPE's payload.
        assert_eq!(payload, client.finish_handshake(&handshake2).unwrap());

        // Check that the session works.
        let mut buffer = Message::from_slice("message".as_bytes()).unwrap();
        client.encrypt(&mut buffer).unwrap();
        SessionCryptoForTesting::session_decrypt(
            &mut dpe_decrypt_cs,
            &mut buffer,
        )
        .unwrap();
        assert_eq!("message".as_bytes(), buffer.as_slice());
        // DPE -> client reply. The DPE side commits its encrypt state
        // explicitly, as `SessionClient::encrypt` does for the client.
        SessionCryptoForTesting::session_encrypt(
            &mut dpe_encrypt_cs,
            &mut buffer,
        )
        .unwrap();
        dpe_encrypt_cs.commit();
        client.decrypt(&mut buffer).unwrap();
        assert_eq!("message".as_bytes(), buffer.as_slice());

        // Do it again to check session state still works.
        client.encrypt(&mut buffer).unwrap();
        SessionCryptoForTesting::session_decrypt(
            &mut dpe_decrypt_cs,
            &mut buffer,
        )
        .unwrap();
        assert_eq!("message".as_bytes(), buffer.as_slice());
        SessionCryptoForTesting::session_encrypt(
            &mut dpe_encrypt_cs,
            &mut buffer,
        )
        .unwrap();
        dpe_encrypt_cs.commit();
        client.decrypt(&mut buffer).unwrap();
        assert_eq!("message".as_bytes(), buffer.as_slice());
    }

    // Establishes a session, derives a second session from it through the
    // PSK path (`derive_psk` on the client, `derive_psk_from_session` plus
    // `derive_session_handshake` on the DPE side), and checks that both
    // sessions keep working afterwards.
    #[test]
    fn derived_session() {
        // Set up a session from which to derive.
        let mut client = SessionClientForTesting::new();
        let dh_key: DhPrivateKey = Default::default();
        let dh_public_key = get_dh_public_key::<
            <DepsForTesting as NoiseCryptoDeps>::DH,
        >(&dh_key)
        .unwrap();
        let handshake1 = client
            .start_handshake_with_known_public_key(&dh_public_key)
            .unwrap();
        let mut dpe_decrypt_cs = Default::default();
        let mut dpe_encrypt_cs = Default::default();
        let mut psk_seed = Default::default();
        let mut handshake2 = Default::default();
        let payload = HandshakePayload::from_slice("pay".as_bytes()).unwrap();
        SessionCryptoForTesting::new_session_handshake(
            &dh_key,
            &handshake1,
            &payload,
            &mut handshake2,
            &mut dpe_decrypt_cs,
            &mut dpe_encrypt_cs,
            &mut psk_seed,
        )
        .unwrap();
        assert_eq!(payload, client.finish_handshake(&handshake2).unwrap());

        // Derive a second session.
        let mut client2 = SessionClientForTesting::new();
        // The client takes its PSK before encrypting the next command, while
        // the DPE computes its PSK after decrypting that command (below).
        let client_psk = client.derive_psk();
        // Simulate the session state after command decryption on the DPE side
        // as expected by the DPE PSK logic.
        let mut buffer = Message::from_slice("message".as_bytes()).unwrap();
        client.encrypt(&mut buffer).unwrap();
        SessionCryptoForTesting::session_decrypt(
            &mut dpe_decrypt_cs,
            &mut buffer,
        )
        .unwrap();
        let dpe_psk = SessionCryptoForTesting::derive_psk_from_session(
            &psk_seed,
            &dpe_decrypt_cs,
            &dpe_encrypt_cs,
        )
        .unwrap();
        let handshake1 = client2.start_handshake_with_psk(&client_psk).unwrap();
        let mut dpe_decrypt_cs2 = Default::default();
        let mut dpe_encrypt_cs2 = Default::default();
        let mut psk_seed2 = Default::default();
        // The `handshake2` buffer from the first handshake is reused as the
        // output buffer for the derived handshake.
        SessionCryptoForTesting::derive_session_handshake(
            &dpe_psk,
            &handshake1,
            &payload,
            &mut handshake2,
            &mut dpe_decrypt_cs2,
            &mut dpe_encrypt_cs2,
            &mut psk_seed2,
        )
        .unwrap();
        assert_eq!(payload, client2.finish_handshake(&handshake2).unwrap());

        // Check that the second session works.
        let mut buffer = Message::from_slice("message".as_bytes()).unwrap();
        client2.encrypt(&mut buffer).unwrap();
        SessionCryptoForTesting::session_decrypt(
            &mut dpe_decrypt_cs2,
            &mut buffer,
        )
        .unwrap();
        assert_eq!("message".as_bytes(), buffer.as_slice());
        SessionCryptoForTesting::session_encrypt(
            &mut dpe_encrypt_cs2,
            &mut buffer,
        )
        .unwrap();
        dpe_encrypt_cs2.commit();
        client2.decrypt(&mut buffer).unwrap();
        assert_eq!("message".as_bytes(), buffer.as_slice());

        // Check that the first session also still works.
        let mut buffer = Message::from_slice("message".as_bytes()).unwrap();
        client.encrypt(&mut buffer).unwrap();
        SessionCryptoForTesting::session_decrypt(
            &mut dpe_decrypt_cs,
            &mut buffer,
        )
        .unwrap();
        assert_eq!("message".as_bytes(), buffer.as_slice());
        SessionCryptoForTesting::session_encrypt(
            &mut dpe_encrypt_cs,
            &mut buffer,
        )
        .unwrap();
        dpe_encrypt_cs.commit();
        client.decrypt(&mut buffer).unwrap();
        assert_eq!("message".as_bytes(), buffer.as_slice());
    }
}
699