1 #![allow(missing_docs)] 2 // Copyright 2023 Google LLC 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 use crate::d2d_connection_context_v1::D2DConnectionContextV1; 17 use crypto_provider::CryptoProvider; 18 use rand::SeedableRng as _; 19 use std::{collections::HashSet, mem}; 20 use ukey2_rs::{ 21 CompletedHandshake, HandshakeImplementation, NextProtocol, StateMachine, Ukey2Client, 22 Ukey2ClientStage1, Ukey2Server, Ukey2ServerStage1, Ukey2ServerStage2, 23 }; 24 25 #[derive(Debug)] 26 pub enum HandshakeError { 27 HandshakeNotComplete, 28 } 29 30 #[derive(Debug)] 31 pub enum HandleMessageError { 32 /// The supplied message was not applicable for the current state 33 InvalidState, 34 /// Handling the message produced an error that should be sent to the other party 35 ErrorMessage(Vec<u8>), 36 /// Bad message 37 BadMessage, 38 } 39 40 /// Implements UKEY2 and produces a [`D2DConnectionContextV1`]. 41 /// This class should be kept compatible with the Java and C++ implementations in 42 /// <https://github.com/google/ukey2>. 43 /// 44 /// For usage examples, see `ukey2_shell`. This file contains a shell exercising 45 /// both the initiator and responder handshake roles. 46 pub trait D2DHandshakeContext<R = rand::rngs::StdRng>: Send 47 where 48 R: rand::RngCore + rand::CryptoRng + rand::SeedableRng + Send, 49 { 50 /// Tells the caller whether the handshake has completed or not. If the handshake is complete, 51 /// the caller may call [`to_connection_context`][Self::to_connection_context] to obtain a 52 /// connection context. 53 /// 54 /// Returns true if the handshake is complete, false otherwise. is_handshake_complete(&self) -> bool55 fn is_handshake_complete(&self) -> bool; 56 57 /// Constructs the next message that should be sent in the handshake. 58 /// 59 /// Returns the next message or `None` if the handshake is over. get_next_handshake_message(&self) -> Option<Vec<u8>>60 fn get_next_handshake_message(&self) -> Option<Vec<u8>>; 61 62 /// Parses a handshake message and advances the internal state of the context. 63 /// 64 /// * `handshakeMessage` - message received from the remote end in the handshake handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError>65 fn handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError>; 66 67 /// Creates a [`D2DConnectionContextV1`] using the results of the handshake. May only be called 68 /// if [`is_handshake_complete`][Self::is_handshake_complete] returns true. Before trusting the 69 /// connection, callers should check that `to_completed_handshake().auth_string()` matches on 70 /// the client and server sides first. See the documentation for 71 /// [`to_completed_handshake`][Self::to_completed_handshake]. to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError>72 fn to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError>; 73 74 /// Returns the [`CompletedHandshake`] using the results from this handshake context. May only 75 /// be called if [`is_handshake_complete`][Self::is_handshake_complete] returns true. 76 /// Callers should verify that the authentication strings from 77 /// `to_completed_handshake().auth_string()` matches on the server and client sides before 78 /// trying to create a connection context. This authentication string verification needs to be 79 /// done out-of-band, either by displaying the string to the user, or verified by some other 80 /// secure means. to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError>81 fn to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError>; 82 } 83 84 enum InitiatorState<C: CryptoProvider> { 85 Stage1(Ukey2ClientStage1<C>), 86 Complete(Ukey2Client), 87 /// If the initiator enters into an invalid state, e.g. by receiving invalid input. 88 /// Also a momentary placeholder while swapping out states. 89 Invalid, 90 } 91 92 /// Implementation of [`D2DHandshakeContext`] for the initiator (a.k.a the client). 93 pub struct InitiatorD2DHandshakeContext<C: CryptoProvider, R = rand::rngs::StdRng> 94 where 95 R: rand::RngCore + rand::CryptoRng + rand::SeedableRng + Send, 96 { 97 state: InitiatorState<C>, 98 rng: R, 99 } 100 101 impl<C: CryptoProvider> InitiatorD2DHandshakeContext<C, rand::rngs::StdRng> { new(handshake_impl: HandshakeImplementation, next_protocols: Vec<NextProtocol>) -> Self102 pub fn new(handshake_impl: HandshakeImplementation, next_protocols: Vec<NextProtocol>) -> Self { 103 Self::new_impl(handshake_impl, rand::rngs::StdRng::from_entropy(), next_protocols) 104 } 105 } 106 107 impl<C: CryptoProvider, R> InitiatorD2DHandshakeContext<C, R> 108 where 109 R: rand::RngCore + rand::CryptoRng + rand::SeedableRng + Send, 110 { 111 // Used for testing / fuzzing only. 112 #[doc(hidden)] new_impl( handshake_impl: HandshakeImplementation, mut rng: R, next_protocols: Vec<NextProtocol>, ) -> Self113 pub fn new_impl( 114 handshake_impl: HandshakeImplementation, 115 mut rng: R, 116 next_protocols: Vec<NextProtocol>, 117 ) -> Self { 118 let client = Ukey2ClientStage1::from(&mut rng, next_protocols, handshake_impl); 119 Self { state: InitiatorState::Stage1(client), rng } 120 } 121 } 122 123 impl<C: CryptoProvider, R> D2DHandshakeContext<R> for InitiatorD2DHandshakeContext<C, R> 124 where 125 R: rand::RngCore + rand::CryptoRng + rand::SeedableRng + Send, 126 { is_handshake_complete(&self) -> bool127 fn is_handshake_complete(&self) -> bool { 128 match self.state { 129 InitiatorState::Stage1(_) => false, 130 InitiatorState::Complete(_) => true, 131 InitiatorState::Invalid => false, 132 } 133 } 134 get_next_handshake_message(&self) -> Option<Vec<u8>>135 fn get_next_handshake_message(&self) -> Option<Vec<u8>> { 136 let next_msg = match &self.state { 137 InitiatorState::Stage1(c) => Some(c.client_init_msg().to_vec()), 138 InitiatorState::Complete(c) => Some(c.client_finished_msg().to_vec()), 139 InitiatorState::Invalid => None, 140 }?; 141 Some(next_msg) 142 } 143 handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError>144 fn handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError> { 145 match mem::replace(&mut self.state, InitiatorState::Invalid) { 146 InitiatorState::Stage1(c) => { 147 let client = c 148 .advance_state(&mut self.rng, message) 149 .map_err(|a| HandleMessageError::ErrorMessage(a.into_wrapped_alert_msg()))?; 150 self.state = InitiatorState::Complete(client); 151 Ok(()) 152 } 153 InitiatorState::Complete(_) | InitiatorState::Invalid => { 154 // already in invalid state, so leave it as is 155 Err(HandleMessageError::InvalidState) 156 } 157 } 158 } 159 to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError>160 fn to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError> { 161 // Since self.rng is expected to be a seeded PRNG, not an OsRng directly, from_rng 162 // should never fail. https://rust-random.github.io/book/guide-err.html 163 let rng = R::from_rng(&mut self.rng).unwrap(); 164 self.to_completed_handshake() 165 .map(|h| D2DConnectionContextV1::from_initiator_handshake::<C>(h, rng)) 166 } 167 to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError>168 fn to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError> { 169 match &self.state { 170 InitiatorState::Stage1(_) | InitiatorState::Invalid => { 171 Err(HandshakeError::HandshakeNotComplete) 172 } 173 InitiatorState::Complete(c) => Ok(c.completed_handshake()), 174 } 175 } 176 } 177 178 enum ServerState<C: CryptoProvider> { 179 Stage1(Ukey2ServerStage1<C>), 180 Stage2(Ukey2ServerStage2<C>), 181 Complete(Ukey2Server), 182 /// If the initiator enters into an invalid state, e.g. by receiving invalid input. 183 /// Also a momentary placeholder while swapping out states. 184 Invalid, 185 } 186 187 /// Implementation of [`D2DHandshakeContext`] for the server. 188 pub struct ServerD2DHandshakeContext<C: CryptoProvider, R = rand::rngs::StdRng> 189 where 190 R: rand::Rng + rand::SeedableRng + rand::CryptoRng + Send, 191 { 192 state: ServerState<C>, 193 rng: R, 194 } 195 196 impl<C: CryptoProvider> ServerD2DHandshakeContext<C, rand::rngs::StdRng> { new(handshake_impl: HandshakeImplementation, next_protocols: &[NextProtocol]) -> Self197 pub fn new(handshake_impl: HandshakeImplementation, next_protocols: &[NextProtocol]) -> Self { 198 Self::new_impl(handshake_impl, rand::rngs::StdRng::from_entropy(), next_protocols) 199 } 200 } 201 202 impl<C: CryptoProvider, R> ServerD2DHandshakeContext<C, R> 203 where 204 R: rand::Rng + rand::SeedableRng + rand::CryptoRng + Send, 205 { 206 // Used for testing / fuzzing only. 207 #[doc(hidden)] new_impl( handshake_impl: HandshakeImplementation, rng: R, next_protocols: &[NextProtocol], ) -> Self208 pub fn new_impl( 209 handshake_impl: HandshakeImplementation, 210 rng: R, 211 next_protocols: &[NextProtocol], 212 ) -> Self { 213 Self { 214 state: ServerState::Stage1(Ukey2ServerStage1::from( 215 HashSet::from_iter(next_protocols.iter().map(|np| np.to_string())), 216 handshake_impl, 217 )), 218 rng, 219 } 220 } 221 } 222 223 impl<C, R> D2DHandshakeContext<R> for ServerD2DHandshakeContext<C, R> 224 where 225 C: CryptoProvider, 226 R: rand::Rng + rand::SeedableRng + rand::CryptoRng + Send, 227 { is_handshake_complete(&self) -> bool228 fn is_handshake_complete(&self) -> bool { 229 match &self.state { 230 ServerState::Complete(_) => true, 231 ServerState::Stage1(_) | ServerState::Stage2(_) | ServerState::Invalid => false, 232 } 233 } 234 get_next_handshake_message(&self) -> Option<Vec<u8>>235 fn get_next_handshake_message(&self) -> Option<Vec<u8>> { 236 let next_msg = match &self.state { 237 ServerState::Stage1(_) => None, 238 ServerState::Stage2(s) => Some(s.server_init_msg().to_vec()), 239 ServerState::Complete(_) => None, 240 ServerState::Invalid => None, 241 }?; 242 Some(next_msg) 243 } 244 handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError>245 fn handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError> { 246 match mem::replace(&mut self.state, ServerState::Invalid) { 247 ServerState::Stage1(s) => { 248 let server2 = s 249 .advance_state(&mut self.rng, message) 250 .map_err(|a| HandleMessageError::ErrorMessage(a.into_wrapped_alert_msg()))?; 251 self.state = ServerState::Stage2(server2); 252 Ok(()) 253 } 254 ServerState::Stage2(s) => { 255 let server = s 256 .advance_state(&mut self.rng, message) 257 .map_err(|a| HandleMessageError::ErrorMessage(a.into_wrapped_alert_msg()))?; 258 self.state = ServerState::Complete(server); 259 Ok(()) 260 } 261 ServerState::Complete(_) | ServerState::Invalid => { 262 Err(HandleMessageError::InvalidState) 263 } 264 } 265 } 266 to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError>267 fn to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError> { 268 match &self.state { 269 ServerState::Stage1(_) | ServerState::Stage2(_) | ServerState::Invalid => { 270 Err(HandshakeError::HandshakeNotComplete) 271 } 272 ServerState::Complete(s) => Ok(s.completed_handshake()), 273 } 274 } 275 to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError>276 fn to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError> { 277 // Since self.rng is expected to be a seeded PRNG, not an OsRng directly, from_rng 278 // should never fail. https://rust-random.github.io/book/guide-err.html 279 let rng = R::from_rng(&mut self.rng).unwrap(); 280 self.to_completed_handshake() 281 .map(|h| D2DConnectionContextV1::from_responder_handshake::<C>(h, rng)) 282 } 283 } 284