xref: /aosp_15_r20/external/crosvm/win_util/src/dpapi.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2024 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #![deny(unsafe_op_in_unsafe_fn)]
6 
7 //! Safe, Rusty wrappers around DPAPI.
8 
9 use std::ffi::c_void;
10 use std::ptr;
11 use std::slice;
12 
13 use anyhow::Context;
14 use anyhow::Result;
15 use winapi::um::dpapi::CryptProtectData;
16 use winapi::um::dpapi::CryptUnprotectData;
17 use winapi::um::winbase::LocalFree;
18 use winapi::um::wincrypt::DATA_BLOB;
19 use zeroize::Zeroize;
20 
21 use crate::syscall_bail;
22 
/// Owning wrapper around a buffer allocated by DPAPI that must be released with `LocalFree`.
///
/// A null pointer is tolerated and treated as an empty buffer. `slice::from_raw_parts` requires a
/// non-null pointer even for zero-length slices, so a null pointer is never handed to it.
pub struct LocalAllocBuffer {
    ptr: *mut u8,
    len: usize,
}

impl LocalAllocBuffer {
    /// Takes ownership of a DPAPI-allocated buffer.
    ///
    /// # Safety
    /// 0. `ptr` is either null (in which case `len` must be 0), or a valid buffer of length `len`
    ///    that is safe to free with LocalFree.
    /// 1. The caller transfers ownership of the buffer to this object on construction.
    unsafe fn new(ptr: *mut u8, len: usize) -> Self {
        Self { ptr, len }
    }

    /// Returns the buffer contents as a mutable byte slice (empty when the pointer is null).
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        if self.ptr.is_null() {
            // from_raw_parts_mut must never see a null pointer, even with a zero length.
            return &mut [];
        }
        // SAFETY: ptr is non-null and, per the contract of `new`, points to a buffer of length
        // len that this object exclusively owns.
        unsafe { slice::from_raw_parts_mut(self.ptr, self.len) }
    }

    /// Returns the buffer contents as a byte slice (empty when the pointer is null).
    pub fn as_slice(&self) -> &[u8] {
        if self.ptr.is_null() {
            // from_raw_parts must never see a null pointer, even with a zero length.
            return &[];
        }
        // SAFETY: ptr is non-null and, per the contract of `new`, points to a buffer of length
        // len that this object owns.
        unsafe { slice::from_raw_parts(self.ptr, self.len) }
    }
}
47 
48 impl Drop for LocalAllocBuffer {
drop(&mut self)49     fn drop(&mut self) {
50         // This buffer likely contains cryptographic key material. Zero it.
51         self.as_mut_slice().zeroize();
52 
53         // SAFETY: when this struct is created, the caller guarantees
54         // ptr is a valid pointer to a buffer that can be freed with LocalFree.
55         unsafe {
56             LocalFree(self.ptr as *mut c_void);
57         }
58     }
59 }
60 
61 /// # Summary
62 /// Wrapper around CryptProtectData that displays no UI.
crypt_protect_data(plaintext: &mut [u8]) -> Result<LocalAllocBuffer>63 pub fn crypt_protect_data(plaintext: &mut [u8]) -> Result<LocalAllocBuffer> {
64     let mut plaintext_blob = DATA_BLOB {
65         cbData: plaintext
66             .len()
67             .try_into()
68             .context("plaintext size won't fit in DWORD")?,
69         pbData: plaintext.as_mut_ptr(),
70     };
71     let mut ciphertext_blob = DATA_BLOB {
72         cbData: 0,
73         pbData: ptr::null_mut(),
74     };
75 
76     // SAFETY: the FFI call is safe because
77     // 1. plaintext_blob lives longer than the call.
78     // 2. ciphertext_blob lives longer than the call, and we later give ownership of the memory the
79     //    kernel allocates to LocalAllocBuffer which guarantees it is freed.
80     let res = unsafe {
81         CryptProtectData(
82             &mut plaintext_blob as *mut _,
83             /* szDataDescr= */ ptr::null_mut(),
84             /* pOptionalEntropy= */ ptr::null_mut(),
85             /* pvReserved= */ ptr::null_mut(),
86             /* pPromptStruct */ ptr::null_mut(),
87             /* dwFlags */ 0,
88             &mut ciphertext_blob as *mut _,
89         )
90     };
91     if res == 0 {
92         syscall_bail!("CryptProtectData failed");
93     }
94 
95     let ciphertext_len: usize = ciphertext_blob
96         .cbData
97         .try_into()
98         .context("resulting ciphertext had an invalid size")?;
99 
100     // SAFETY: safe because ciphertext_blob refers to a valid buffer of the specified length. This
101     // is guaranteed because CryptProtectData returned success.
102     Ok(unsafe { LocalAllocBuffer::new(ciphertext_blob.pbData, ciphertext_len) })
103 }
104 
105 /// # Summary
106 /// Wrapper around CryptProtectData that displays no UI.
crypt_unprotect_data(ciphertext: &mut [u8]) -> Result<LocalAllocBuffer>107 pub fn crypt_unprotect_data(ciphertext: &mut [u8]) -> Result<LocalAllocBuffer> {
108     let mut ciphertext_blob = DATA_BLOB {
109         cbData: ciphertext
110             .len()
111             .try_into()
112             .context("plaintext size won't fit in DWORD")?,
113         pbData: ciphertext.as_mut_ptr(),
114     };
115     let mut plaintext_blob = DATA_BLOB {
116         cbData: 0,
117         pbData: ptr::null_mut(),
118     };
119 
120     // SAFETY: the FFI call is safe because
121     // 1. ciphertext_blob lives longer than the call.
122     // 2. plaintext_blob lives longer than the call, and we later give ownership of the memory the
123     //    kernel allocates to LocalAllocBuffer which guarantees it is freed.
124     let res = unsafe {
125         CryptUnprotectData(
126             &mut ciphertext_blob as *mut _,
127             /* szDataDescr= */ ptr::null_mut(),
128             /* pOptionalEntropy= */ ptr::null_mut(),
129             /* pvReserved= */ ptr::null_mut(),
130             /* pPromptStruct */ ptr::null_mut(),
131             /* dwFlags */ 0,
132             &mut plaintext_blob as *mut _,
133         )
134     };
135     if res == 0 {
136         syscall_bail!("CryptUnprotectData failed");
137     }
138 
139     let plaintext_len: usize = plaintext_blob
140         .cbData
141         .try_into()
142         .context("resulting plaintext had an invalid size")?;
143 
144     // SAFETY: safe because plaintext_blob refers to a valid buffer of the specified length. This
145     // is guaranteed because CryptUnprotectData returned success.
146     Ok(unsafe { LocalAllocBuffer::new(plaintext_blob.pbData, plaintext_len) })
147 }
148 
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encrypt_empty_string_is_valid() {
        let plaintext_str = "";
        let mut plaintext_buffer = Vec::from(plaintext_str.as_bytes());

        let mut ciphertext_buffer = crypt_protect_data(plaintext_buffer.as_mut_slice()).unwrap();
        let decrypted_plaintext_buffer =
            crypt_unprotect_data(ciphertext_buffer.as_mut_slice()).unwrap();
        let decrypted_plaintext_str =
            std::str::from_utf8(decrypted_plaintext_buffer.as_slice()).unwrap();
        assert_eq!(plaintext_str, decrypted_plaintext_str);
    }

    #[test]
    fn encrypt_decrypt_plaintext_matches() {
        let plaintext_str = "test plaintext";
        let mut plaintext_buffer = Vec::from(plaintext_str.as_bytes());

        let mut ciphertext_buffer = crypt_protect_data(plaintext_buffer.as_mut_slice()).unwrap();

        // If our plaintext & ciphertext are the same, something is very wrong.
        assert_ne!(plaintext_str.as_bytes(), ciphertext_buffer.as_slice());

        // Decrypt the ciphertext and make sure it's our original plaintext.
        let decrypted_plaintext_buffer =
            crypt_unprotect_data(ciphertext_buffer.as_mut_slice()).unwrap();
        let decrypted_plaintext_str =
            std::str::from_utf8(decrypted_plaintext_buffer.as_slice()).unwrap();
        assert_eq!(plaintext_str, decrypted_plaintext_str);
    }

    #[test]
    fn encrypt_decrypt_binary_data_matches() {
        // The wrappers operate on raw bytes, so non-UTF-8 input (including an interior NUL) must
        // survive a round trip unchanged.
        let plaintext: Vec<u8> = (0..=255u8).collect();
        let mut plaintext_buffer = plaintext.clone();

        let mut ciphertext_buffer = crypt_protect_data(plaintext_buffer.as_mut_slice()).unwrap();
        assert_ne!(plaintext.as_slice(), ciphertext_buffer.as_slice());
        // Encryption should only read the input buffer, never modify it.
        assert_eq!(plaintext, plaintext_buffer);

        let decrypted_plaintext_buffer =
            crypt_unprotect_data(ciphertext_buffer.as_mut_slice()).unwrap();
        assert_eq!(plaintext.as_slice(), decrypted_plaintext_buffer.as_slice());
    }
}
184