1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use std::collections::HashMap;
16 use std::ptr::null_mut;
17 
18 use crypto_provider_default::CryptoProviderImpl as CryptoProvider;
19 use lazy_static::lazy_static;
20 use lock_adapter::NoPoisonMutex;
21 use rand::Rng;
22 use rand_chacha::rand_core::SeedableRng;
23 use rand_chacha::ChaCha20Rng;
24 
25 #[cfg(not(feature = "std"))]
26 use lock_adapter::spin::Mutex;
27 #[cfg(feature = "std")]
28 use lock_adapter::stdlib::Mutex;
29 
30 use ukey2_connections::{
31     D2DConnectionContextV1, D2DHandshakeContext, HandleMessageError, HandshakeImplementation,
32     InitiatorD2DHandshakeContext, NextProtocol, ServerD2DHandshakeContext,
33 };
34 
/// Byte buffer allocated by Rust and handed to the C caller.
///
/// The three fields are the raw parts of a `Vec<u8>` whose destructor was suppressed. A null
/// `ptr` means "no data"; the functions in this file then fill `len` and `cap` with
/// `usize::MAX`.
#[repr(C)]
pub struct RustFFIByteArray {
    /// Start of the buffer, or null when there is no data.
    ptr: *mut u8,
    /// Number of initialized bytes at `ptr`.
    len: usize,
    /// Capacity of the original allocation; required to hand the buffer back to the allocator.
    cap: usize,
}

impl RustFFIByteArray {
    /// Gives up ownership of `vec` by splitting it into raw parts that can cross the FFI boundary.
    fn from_vec(vec: Vec<u8>) -> RustFFIByteArray {
        // Prevent the Vec from freeing its buffer; whoever holds the raw parts now owns it.
        let mut leaked = core::mem::ManuallyDrop::new(vec);
        let (ptr, len, cap) = (leaked.as_mut_ptr(), leaked.len(), leaked.capacity());
        RustFFIByteArray { ptr, len, cap }
    }

    /// Reassembles the `Vec<u8>` described by these raw parts, or returns `None` for a null `ptr`.
    ///
    /// # Safety
    /// A non-null `ptr` must come from [`RustFFIByteArray::from_vec`] with `len` and `cap`
    /// unchanged, and the buffer must not have been reclaimed already (otherwise it is freed twice).
    unsafe fn into_vec(self) -> Option<Vec<u8>> {
        if self.ptr.is_null() {
            None
        } else {
            Some(Vec::from_raw_parts(self.ptr, self.len, self.cap))
        }
    }
}
55 
/// Byte buffer owned by the C caller and only borrowed by Rust.
///
/// The functions in this file build temporary slices from it with `std::slice::from_raw_parts`
/// and never free it.
#[repr(C)]
pub struct CFFIByteArray {
    /// Start of the caller's bytes. Some entry points treat a null pointer as "not provided"
    /// (e.g. the optional associated data of `encode_message_to_peer`).
    ptr: *mut u8,
    /// Number of bytes readable at `ptr`.
    len: usize,
}
61 
/// Outcome returned to C by `parse_handshake_message`.
#[repr(C)]
pub struct CMessageParseResult {
    /// Whether `parse_handshake_message` reported the message as accepted.
    success: bool,
    /// Alert bytes taken from `HandleMessageError::ErrorMessage`, as a Rust-owned buffer that
    /// must be released with `rust_dealloc_ffi_byte_array`. A null `ptr` means there is no alert.
    alert_to_send: RustFFIByteArray,
}
67 
/// Type-erased handshake context (initiator or responder) stored behind an opaque handle.
type D2DBox = Box<dyn D2DHandshakeContext>;
/// Connection context stored behind an opaque connection handle.
type ConnectionBox = Box<D2DConnectionContextV1>;

// Process-wide registries that map the opaque `u64` handles given to C back to Rust objects.
// `Mutex` is the std or spin implementation selected by the `std` feature (see the imports).
lazy_static! {
    // Handshake contexts created by `responder_new` / `initiator_new`.
    static ref HANDLE_MAPPING: Mutex<HashMap<u64, D2DBox>> = Mutex::new(HashMap::new());
    // Connection contexts created by `to_connection_context` / `from_saved_session`.
    static ref CONNECTION_HANDLE_MAPPING: Mutex<HashMap<u64, ConnectionBox>> =
        Mutex::new(HashMap::new());
    // Source of handle values (used in this file only by `generate_handle`); seeded once via
    // `from_entropy`.
    static ref RNG: Mutex<ChaCha20Rng> = Mutex::new(ChaCha20Rng::from_entropy());
}
77 
generate_handle() -> u6478 fn generate_handle() -> u64 {
79     RNG.lock().gen()
80 }
81 
insert_gen_handle(item: D2DBox) -> u6482 fn insert_gen_handle(item: D2DBox) -> u64 {
83     let handle = generate_handle();
84     HANDLE_MAPPING.lock().insert(handle, item);
85     handle
86 }
87 
insert_conn_gen_handle(item: ConnectionBox) -> u6488 fn insert_conn_gen_handle(item: ConnectionBox) -> u64 {
89     let handle = generate_handle();
90     CONNECTION_HANDLE_MAPPING.lock().insert(handle, item);
91     handle
92 }
93 
94 // Utilities
95 /// This function deallocates FFIByteArray instances allocated from Rust only.
96 /// NOTE: Any FFIByteArray instances deallocated by this function will no longer be in a guaranteed
97 /// usable state.
98 ///
99 /// # Safety
100 /// The array must have been allocated by a Rust function with the Rust allocator, e.g.
101 /// [get_next_handshake_message].
102 #[no_mangle]
rust_dealloc_ffi_byte_array(arr: RustFFIByteArray)103 pub unsafe extern "C" fn rust_dealloc_ffi_byte_array(arr: RustFFIByteArray) {
104     if let Some(vec) = arr.into_vec() {
105         core::mem::drop(vec);
106     }
107 }
108 
109 // Common functions
110 #[no_mangle]
is_handshake_complete(handle: u64) -> bool111 pub extern "C" fn is_handshake_complete(handle: u64) -> bool {
112     HANDLE_MAPPING.lock().get(&handle).map_or(false, |ctx| ctx.is_handshake_complete())
113 }
114 
115 #[no_mangle]
get_next_handshake_message(handle: u64) -> RustFFIByteArray116 pub extern "C" fn get_next_handshake_message(handle: u64) -> RustFFIByteArray {
117     // TODO: error handling
118     let opt_msg = HANDLE_MAPPING.lock().get(&handle).and_then(|c| c.get_next_handshake_message());
119     if let Some(msg) = opt_msg {
120         RustFFIByteArray::from_vec(msg)
121     } else {
122         RustFFIByteArray { ptr: null_mut(), len: usize::MAX, cap: usize::MAX }
123     }
124 }
125 
126 /// # Safety
127 /// We treat msg as data, so we should never have an issue trying to execute it.
128 #[no_mangle]
parse_handshake_message( handle: u64, arr: CFFIByteArray, ) -> CMessageParseResult129 pub unsafe extern "C" fn parse_handshake_message(
130     handle: u64,
131     arr: CFFIByteArray,
132 ) -> CMessageParseResult {
133     let msg = std::slice::from_raw_parts(arr.ptr, arr.len);
134     let result = HANDLE_MAPPING.lock().get_mut(&handle).unwrap().handle_handshake_message(msg);
135     if let Err(error) = result {
136         match error {
137             HandleMessageError::InvalidState | HandleMessageError::BadMessage => {
138                 log::error!("{:?}", error);
139             }
140             HandleMessageError::ErrorMessage(message) => {
141                 return CMessageParseResult {
142                     success: false,
143                     alert_to_send: RustFFIByteArray::from_vec(message),
144                 };
145             }
146         }
147     }
148     CMessageParseResult {
149         success: true,
150         alert_to_send: RustFFIByteArray { ptr: null_mut(), len: usize::MAX, cap: usize::MAX },
151     }
152 }
153 
154 #[no_mangle]
get_verification_string(handle: u64, length: usize) -> RustFFIByteArray155 pub extern "C" fn get_verification_string(handle: u64, length: usize) -> RustFFIByteArray {
156     HANDLE_MAPPING
157         .lock()
158         .get(&handle)
159         .map(|h| {
160             let auth_vec = h
161                 .to_completed_handshake()
162                 .unwrap()
163                 .auth_string::<CryptoProvider>()
164                 .derive_vec(length)
165                 .unwrap();
166             RustFFIByteArray::from_vec(auth_vec)
167         })
168         .unwrap()
169 }
170 
171 #[no_mangle]
to_connection_context(handle: u64) -> u64172 pub extern "C" fn to_connection_context(handle: u64) -> u64 {
173     // TODO: error handling
174     let ctx = HANDLE_MAPPING
175         .lock()
176         .remove(&handle)
177         .map(move |mut ctx| {
178             let result = Box::new(ctx.to_connection_context().unwrap());
179             drop(ctx);
180             result
181         })
182         .unwrap();
183     insert_conn_gen_handle(ctx)
184 }
185 
186 // Responder-specific functions
187 #[no_mangle]
responder_new() -> u64188 pub extern "C" fn responder_new() -> u64 {
189     let ctx = Box::new(ServerD2DHandshakeContext::<CryptoProvider>::new(
190         HandshakeImplementation::PublicKeyInProtobuf,
191         &[NextProtocol::Aes256CbcHmacSha256],
192     ));
193     insert_gen_handle(ctx)
194 }
195 
196 // Initiator-specific functions
197 
198 /// # Safety
199 /// We treat next_protocol as data, not as executable memory.
200 #[no_mangle]
initiator_new() -> u64201 pub extern "C" fn initiator_new() -> u64 {
202     let ctx = Box::new(InitiatorD2DHandshakeContext::<CryptoProvider>::new(
203         HandshakeImplementation::PublicKeyInProtobuf,
204         vec![NextProtocol::Aes256CbcHmacSha256],
205     ));
206     insert_gen_handle(ctx)
207 }
208 
209 // Connection Context
210 
211 /// # Safety
212 /// We treat msg and associated_data as data, not as executable memory.
213 /// associated_data and msg are slices so Rust won't try to do anything weird with allocation.
214 #[no_mangle]
encode_message_to_peer( handle: u64, msg: CFFIByteArray, associated_data: CFFIByteArray, ) -> RustFFIByteArray215 pub unsafe extern "C" fn encode_message_to_peer(
216     handle: u64,
217     msg: CFFIByteArray,
218     associated_data: CFFIByteArray,
219 ) -> RustFFIByteArray {
220     if msg.len == 0 {
221         return RustFFIByteArray { ptr: null_mut(), len: usize::MAX, cap: usize::MAX };
222     }
223     let msg = std::slice::from_raw_parts(msg.ptr, msg.len);
224     let associated_data = if !associated_data.ptr.is_null() {
225         Some(std::slice::from_raw_parts(associated_data.ptr, associated_data.len))
226     } else {
227         None
228     };
229     let ret = CONNECTION_HANDLE_MAPPING
230         .lock()
231         .get_mut(&handle)
232         .map(|c| c.encode_message_to_peer::<CryptoProvider, _>(msg, associated_data));
233     if let Some(msg) = ret {
234         RustFFIByteArray::from_vec(msg)
235     } else {
236         log::error!("Was unable to find handle!");
237         RustFFIByteArray { ptr: null_mut(), len: usize::MAX, cap: usize::MAX }
238     }
239 }
240 
241 /// # Safety
242 /// We treat msg as data, not as executable memory.
243 #[no_mangle]
decode_message_from_peer( handle: u64, msg: CFFIByteArray, associated_data: CFFIByteArray, ) -> RustFFIByteArray244 pub unsafe extern "C" fn decode_message_from_peer(
245     handle: u64,
246     msg: CFFIByteArray,
247     associated_data: CFFIByteArray,
248 ) -> RustFFIByteArray {
249     if msg.len == 0 {
250         return RustFFIByteArray { ptr: null_mut(), len: usize::MAX, cap: usize::MAX };
251     }
252     let msg = std::slice::from_raw_parts(msg.ptr, msg.len);
253     let associated_data = if !associated_data.ptr.is_null() {
254         Some(std::slice::from_raw_parts(associated_data.ptr, associated_data.len))
255     } else {
256         None
257     };
258     let ret: Result<Vec<u8>, ukey2_connections::DecodeError> = CONNECTION_HANDLE_MAPPING
259         .lock()
260         .get_mut(&handle)
261         .unwrap()
262         .decode_message_from_peer::<CryptoProvider, _>(msg, associated_data);
263     if let Ok(decoded) = ret {
264         RustFFIByteArray::from_vec(decoded)
265     } else {
266         RustFFIByteArray { ptr: null_mut(), len: usize::MAX, cap: usize::MAX }
267     }
268 }
269 
270 #[no_mangle]
get_session_unique(handle: u64) -> RustFFIByteArray271 pub extern "C" fn get_session_unique(handle: u64) -> RustFFIByteArray {
272     let session_unique_bytes = CONNECTION_HANDLE_MAPPING
273         .lock()
274         .get(&handle)
275         .unwrap()
276         .get_session_unique::<CryptoProvider>();
277     RustFFIByteArray::from_vec(session_unique_bytes)
278 }
279 
280 #[no_mangle]
get_sequence_number_for_encoding(handle: u64) -> i32281 pub extern "C" fn get_sequence_number_for_encoding(handle: u64) -> i32 {
282     CONNECTION_HANDLE_MAPPING.lock().get(&handle).unwrap().get_sequence_number_for_encoding()
283 }
284 
285 #[no_mangle]
get_sequence_number_for_decoding(handle: u64) -> i32286 pub extern "C" fn get_sequence_number_for_decoding(handle: u64) -> i32 {
287     CONNECTION_HANDLE_MAPPING.lock().get(&handle).unwrap().get_sequence_number_for_decoding()
288 }
289 
290 #[no_mangle]
save_session(handle: u64) -> RustFFIByteArray291 pub extern "C" fn save_session(handle: u64) -> RustFFIByteArray {
292     let key = CONNECTION_HANDLE_MAPPING.lock().get(&handle).unwrap().save_session();
293     RustFFIByteArray::from_vec(key)
294 }
295 
/// Status code reported to C. `#[repr(i32)]` fixes the representation, so C observes
/// `Good == 0` and `Error == 1`.
#[repr(i32)]
#[derive(Debug)]
pub enum Status {
    /// The operation succeeded.
    Good = 0,
    /// The operation failed.
    Error = 1,
}
302 
/// Outcome returned to C by `from_saved_session`.
#[repr(C)]
pub struct CD2DRestoreConnectionContextV1Result {
    /// Handle of the restored connection when `status` is `Status::Good`; `u64::MAX` when
    /// `status` is `Status::Error`.
    handle: u64,
    /// Whether the session was restored.
    status: Status,
}
308 
309 /// # Safety
310 /// We error out if the length is incorrect (too large or too small) for restoring a session.
311 #[no_mangle]
from_saved_session( arr: CFFIByteArray, ) -> CD2DRestoreConnectionContextV1Result312 pub unsafe extern "C" fn from_saved_session(
313     arr: CFFIByteArray,
314 ) -> CD2DRestoreConnectionContextV1Result {
315     let saved_session = std::slice::from_raw_parts(arr.ptr, arr.len);
316     let ctx = D2DConnectionContextV1::from_saved_session::<CryptoProvider>(saved_session);
317     if let Ok(conn_ctx) = ctx {
318         let final_ctx = Box::new(conn_ctx);
319         CD2DRestoreConnectionContextV1Result {
320             handle: insert_conn_gen_handle(final_ctx),
321             status: Status::Good,
322         }
323     } else {
324         log::error!("failed to restore session with error {:?}", ctx.unwrap_err());
325         CD2DRestoreConnectionContextV1Result { handle: u64::MAX, status: Status::Error }
326     }
327 }
328