1 #![allow(missing_docs)]
2 // Copyright 2023 Google LLC
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 use crate::d2d_connection_context_v1::D2DConnectionContextV1;
17 use crypto_provider::CryptoProvider;
18 use rand::SeedableRng as _;
19 use std::{collections::HashSet, mem};
20 use ukey2_rs::{
21     CompletedHandshake, HandshakeImplementation, NextProtocol, StateMachine, Ukey2Client,
22     Ukey2ClientStage1, Ukey2Server, Ukey2ServerStage1, Ukey2ServerStage2,
23 };
24 
/// Error returned when the results of a handshake are requested (a connection context or the
/// completed handshake) before the handshake has completed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HandshakeError {
    /// The handshake has not (successfully) completed yet.
    HandshakeNotComplete,
}

impl std::fmt::Display for HandshakeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            HandshakeError::HandshakeNotComplete => f.write_str("handshake is not complete"),
        }
    }
}

impl std::error::Error for HandshakeError {}
29 
/// Error returned by [`D2DHandshakeContext::handle_handshake_message`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HandleMessageError {
    /// The supplied message was not applicable for the current state
    InvalidState,
    /// Handling the message produced an error that should be sent to the other party
    ErrorMessage(Vec<u8>),
    /// Bad message
    BadMessage,
}

impl std::fmt::Display for HandleMessageError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            HandleMessageError::InvalidState => {
                f.write_str("message not applicable for the current handshake state")
            }
            HandleMessageError::ErrorMessage(msg) => write!(
                f,
                "handshake failed; an error message of {} bytes should be sent to the peer",
                msg.len()
            ),
            HandleMessageError::BadMessage => f.write_str("bad handshake message"),
        }
    }
}

impl std::error::Error for HandleMessageError {}
39 
40 /// Implements UKEY2 and produces a [`D2DConnectionContextV1`].
41 /// This class should be kept compatible with the Java and C++ implementations in
42 /// <https://github.com/google/ukey2>.
43 ///
44 /// For usage examples, see `ukey2_shell`. This file contains a shell exercising
45 /// both the initiator and responder handshake roles.
46 pub trait D2DHandshakeContext<R = rand::rngs::StdRng>: Send
47 where
48     R: rand::RngCore + rand::CryptoRng + rand::SeedableRng + Send,
49 {
50     /// Tells the caller whether the handshake has completed or not. If the handshake is complete,
51     /// the caller may call [`to_connection_context`][Self::to_connection_context] to obtain a
52     /// connection context.
53     ///
54     /// Returns true if the handshake is complete, false otherwise.
is_handshake_complete(&self) -> bool55     fn is_handshake_complete(&self) -> bool;
56 
57     /// Constructs the next message that should be sent in the handshake.
58     ///
59     /// Returns the next message or `None` if the handshake is over.
get_next_handshake_message(&self) -> Option<Vec<u8>>60     fn get_next_handshake_message(&self) -> Option<Vec<u8>>;
61 
62     /// Parses a handshake message and advances the internal state of the context.
63     ///
64     /// * `handshakeMessage` - message received from the remote end in the handshake
handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError>65     fn handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError>;
66 
67     /// Creates a [`D2DConnectionContextV1`] using the results of the handshake. May only be called
68     /// if [`is_handshake_complete`][Self::is_handshake_complete] returns true. Before trusting the
69     /// connection, callers should check that `to_completed_handshake().auth_string()` matches on
70     /// the client and server sides first. See the documentation for
71     /// [`to_completed_handshake`][Self::to_completed_handshake].
to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError>72     fn to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError>;
73 
74     /// Returns the [`CompletedHandshake`] using the results from this handshake context. May only
75     /// be called if [`is_handshake_complete`][Self::is_handshake_complete] returns true.
76     /// Callers should verify that the authentication strings from
77     /// `to_completed_handshake().auth_string()` matches on the server and client sides before
78     /// trying to create a connection context. This authentication string verification needs to be
79     /// done out-of-band, either by displaying the string to the user, or verified by some other
80     /// secure means.
to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError>81     fn to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError>;
82 }
83 
84 enum InitiatorState<C: CryptoProvider> {
85     Stage1(Ukey2ClientStage1<C>),
86     Complete(Ukey2Client),
87     /// If the initiator enters into an invalid state, e.g. by receiving invalid input.
88     /// Also a momentary placeholder while swapping out states.
89     Invalid,
90 }
91 
92 /// Implementation of [`D2DHandshakeContext`] for the initiator (a.k.a the client).
93 pub struct InitiatorD2DHandshakeContext<C: CryptoProvider, R = rand::rngs::StdRng>
94 where
95     R: rand::RngCore + rand::CryptoRng + rand::SeedableRng + Send,
96 {
97     state: InitiatorState<C>,
98     rng: R,
99 }
100 
101 impl<C: CryptoProvider> InitiatorD2DHandshakeContext<C, rand::rngs::StdRng> {
new(handshake_impl: HandshakeImplementation, next_protocols: Vec<NextProtocol>) -> Self102     pub fn new(handshake_impl: HandshakeImplementation, next_protocols: Vec<NextProtocol>) -> Self {
103         Self::new_impl(handshake_impl, rand::rngs::StdRng::from_entropy(), next_protocols)
104     }
105 }
106 
107 impl<C: CryptoProvider, R> InitiatorD2DHandshakeContext<C, R>
108 where
109     R: rand::RngCore + rand::CryptoRng + rand::SeedableRng + Send,
110 {
111     // Used for testing / fuzzing only.
112     #[doc(hidden)]
new_impl( handshake_impl: HandshakeImplementation, mut rng: R, next_protocols: Vec<NextProtocol>, ) -> Self113     pub fn new_impl(
114         handshake_impl: HandshakeImplementation,
115         mut rng: R,
116         next_protocols: Vec<NextProtocol>,
117     ) -> Self {
118         let client = Ukey2ClientStage1::from(&mut rng, next_protocols, handshake_impl);
119         Self { state: InitiatorState::Stage1(client), rng }
120     }
121 }
122 
123 impl<C: CryptoProvider, R> D2DHandshakeContext<R> for InitiatorD2DHandshakeContext<C, R>
124 where
125     R: rand::RngCore + rand::CryptoRng + rand::SeedableRng + Send,
126 {
is_handshake_complete(&self) -> bool127     fn is_handshake_complete(&self) -> bool {
128         match self.state {
129             InitiatorState::Stage1(_) => false,
130             InitiatorState::Complete(_) => true,
131             InitiatorState::Invalid => false,
132         }
133     }
134 
get_next_handshake_message(&self) -> Option<Vec<u8>>135     fn get_next_handshake_message(&self) -> Option<Vec<u8>> {
136         let next_msg = match &self.state {
137             InitiatorState::Stage1(c) => Some(c.client_init_msg().to_vec()),
138             InitiatorState::Complete(c) => Some(c.client_finished_msg().to_vec()),
139             InitiatorState::Invalid => None,
140         }?;
141         Some(next_msg)
142     }
143 
handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError>144     fn handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError> {
145         match mem::replace(&mut self.state, InitiatorState::Invalid) {
146             InitiatorState::Stage1(c) => {
147                 let client = c
148                     .advance_state(&mut self.rng, message)
149                     .map_err(|a| HandleMessageError::ErrorMessage(a.into_wrapped_alert_msg()))?;
150                 self.state = InitiatorState::Complete(client);
151                 Ok(())
152             }
153             InitiatorState::Complete(_) | InitiatorState::Invalid => {
154                 // already in invalid state, so leave it as is
155                 Err(HandleMessageError::InvalidState)
156             }
157         }
158     }
159 
to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError>160     fn to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError> {
161         // Since self.rng is expected to be a seeded PRNG, not an OsRng directly, from_rng
162         // should never fail. https://rust-random.github.io/book/guide-err.html
163         let rng = R::from_rng(&mut self.rng).unwrap();
164         self.to_completed_handshake()
165             .map(|h| D2DConnectionContextV1::from_initiator_handshake::<C>(h, rng))
166     }
167 
to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError>168     fn to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError> {
169         match &self.state {
170             InitiatorState::Stage1(_) | InitiatorState::Invalid => {
171                 Err(HandshakeError::HandshakeNotComplete)
172             }
173             InitiatorState::Complete(c) => Ok(c.completed_handshake()),
174         }
175     }
176 }
177 
178 enum ServerState<C: CryptoProvider> {
179     Stage1(Ukey2ServerStage1<C>),
180     Stage2(Ukey2ServerStage2<C>),
181     Complete(Ukey2Server),
182     /// If the initiator enters into an invalid state, e.g. by receiving invalid input.
183     /// Also a momentary placeholder while swapping out states.
184     Invalid,
185 }
186 
187 /// Implementation of [`D2DHandshakeContext`] for the server.
188 pub struct ServerD2DHandshakeContext<C: CryptoProvider, R = rand::rngs::StdRng>
189 where
190     R: rand::Rng + rand::SeedableRng + rand::CryptoRng + Send,
191 {
192     state: ServerState<C>,
193     rng: R,
194 }
195 
196 impl<C: CryptoProvider> ServerD2DHandshakeContext<C, rand::rngs::StdRng> {
new(handshake_impl: HandshakeImplementation, next_protocols: &[NextProtocol]) -> Self197     pub fn new(handshake_impl: HandshakeImplementation, next_protocols: &[NextProtocol]) -> Self {
198         Self::new_impl(handshake_impl, rand::rngs::StdRng::from_entropy(), next_protocols)
199     }
200 }
201 
202 impl<C: CryptoProvider, R> ServerD2DHandshakeContext<C, R>
203 where
204     R: rand::Rng + rand::SeedableRng + rand::CryptoRng + Send,
205 {
206     // Used for testing / fuzzing only.
207     #[doc(hidden)]
new_impl( handshake_impl: HandshakeImplementation, rng: R, next_protocols: &[NextProtocol], ) -> Self208     pub fn new_impl(
209         handshake_impl: HandshakeImplementation,
210         rng: R,
211         next_protocols: &[NextProtocol],
212     ) -> Self {
213         Self {
214             state: ServerState::Stage1(Ukey2ServerStage1::from(
215                 HashSet::from_iter(next_protocols.iter().map(|np| np.to_string())),
216                 handshake_impl,
217             )),
218             rng,
219         }
220     }
221 }
222 
223 impl<C, R> D2DHandshakeContext<R> for ServerD2DHandshakeContext<C, R>
224 where
225     C: CryptoProvider,
226     R: rand::Rng + rand::SeedableRng + rand::CryptoRng + Send,
227 {
is_handshake_complete(&self) -> bool228     fn is_handshake_complete(&self) -> bool {
229         match &self.state {
230             ServerState::Complete(_) => true,
231             ServerState::Stage1(_) | ServerState::Stage2(_) | ServerState::Invalid => false,
232         }
233     }
234 
get_next_handshake_message(&self) -> Option<Vec<u8>>235     fn get_next_handshake_message(&self) -> Option<Vec<u8>> {
236         let next_msg = match &self.state {
237             ServerState::Stage1(_) => None,
238             ServerState::Stage2(s) => Some(s.server_init_msg().to_vec()),
239             ServerState::Complete(_) => None,
240             ServerState::Invalid => None,
241         }?;
242         Some(next_msg)
243     }
244 
handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError>245     fn handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError> {
246         match mem::replace(&mut self.state, ServerState::Invalid) {
247             ServerState::Stage1(s) => {
248                 let server2 = s
249                     .advance_state(&mut self.rng, message)
250                     .map_err(|a| HandleMessageError::ErrorMessage(a.into_wrapped_alert_msg()))?;
251                 self.state = ServerState::Stage2(server2);
252                 Ok(())
253             }
254             ServerState::Stage2(s) => {
255                 let server = s
256                     .advance_state(&mut self.rng, message)
257                     .map_err(|a| HandleMessageError::ErrorMessage(a.into_wrapped_alert_msg()))?;
258                 self.state = ServerState::Complete(server);
259                 Ok(())
260             }
261             ServerState::Complete(_) | ServerState::Invalid => {
262                 Err(HandleMessageError::InvalidState)
263             }
264         }
265     }
266 
to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError>267     fn to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError> {
268         match &self.state {
269             ServerState::Stage1(_) | ServerState::Stage2(_) | ServerState::Invalid => {
270                 Err(HandshakeError::HandshakeNotComplete)
271             }
272             ServerState::Complete(s) => Ok(s.completed_handshake()),
273         }
274     }
275 
to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError>276     fn to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError> {
277         // Since self.rng is expected to be a seeded PRNG, not an OsRng directly, from_rng
278         // should never fail. https://rust-random.github.io/book/guide-err.html
279         let rng = R::from_rng(&mut self.rng).unwrap();
280         self.to_completed_handshake()
281             .map(|h| D2DConnectionContextV1::from_responder_handshake::<C>(h, rng))
282     }
283 }
284