1 // Copyright 2024 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #![deny(unsafe_op_in_unsafe_fn)]
6
7 //! Safe, Rusty wrappers around DPAPI.
8
9 use std::ffi::c_void;
10 use std::ptr;
11 use std::slice;
12
13 use anyhow::Context;
14 use anyhow::Result;
15 use winapi::um::dpapi::CryptProtectData;
16 use winapi::um::dpapi::CryptUnprotectData;
17 use winapi::um::winbase::LocalFree;
18 use winapi::um::wincrypt::DATA_BLOB;
19 use zeroize::Zeroize;
20
21 use crate::syscall_bail;
22
/// Owner of a buffer that DPAPI allocated and that must be released with `LocalFree`.
///
/// The buffer contents are exposed as byte slices; dropping the owner releases the memory.
pub struct LocalAllocBuffer {
    ptr: *mut u8,
    len: usize,
}

impl LocalAllocBuffer {
    /// Takes ownership of a `LocalAlloc`-compatible buffer.
    ///
    /// # Safety
    /// 0. `ptr` is either null (in which case the buffer is treated as empty), or a valid buffer
    ///    of length `len` that is safe to free with `LocalFree`.
    /// 1. The caller transfers ownership of the buffer to this object on construction.
    unsafe fn new(ptr: *mut u8, len: usize) -> Self {
        Self { ptr, len }
    }

    /// Returns the buffer contents as a mutable byte slice (empty if the pointer is null).
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        // `slice::from_raw_parts_mut` requires a non-null pointer even for a zero length, so a
        // null buffer must be special-cased rather than passed through.
        if self.ptr.is_null() {
            return Default::default();
        }
        // SAFETY: ptr is non-null and, per the contract of `new`, points to a buffer of length
        // len that this object exclusively owns; `&mut self` guarantees unique access.
        unsafe { slice::from_raw_parts_mut(self.ptr, self.len) }
    }

    /// Returns the buffer contents as a byte slice (empty if the pointer is null).
    pub fn as_slice(&self) -> &[u8] {
        // `slice::from_raw_parts` requires a non-null pointer even for a zero length.
        if self.ptr.is_null() {
            return &[];
        }
        // SAFETY: ptr is non-null and, per the contract of `new`, points to a buffer of length
        // len owned by this object.
        unsafe { slice::from_raw_parts(self.ptr, self.len) }
    }
}
47
48 impl Drop for LocalAllocBuffer {
drop(&mut self)49 fn drop(&mut self) {
50 // This buffer likely contains cryptographic key material. Zero it.
51 self.as_mut_slice().zeroize();
52
53 // SAFETY: when this struct is created, the caller guarantees
54 // ptr is a valid pointer to a buffer that can be freed with LocalFree.
55 unsafe {
56 LocalFree(self.ptr as *mut c_void);
57 }
58 }
59 }
60
61 /// # Summary
62 /// Wrapper around CryptProtectData that displays no UI.
crypt_protect_data(plaintext: &mut [u8]) -> Result<LocalAllocBuffer>63 pub fn crypt_protect_data(plaintext: &mut [u8]) -> Result<LocalAllocBuffer> {
64 let mut plaintext_blob = DATA_BLOB {
65 cbData: plaintext
66 .len()
67 .try_into()
68 .context("plaintext size won't fit in DWORD")?,
69 pbData: plaintext.as_mut_ptr(),
70 };
71 let mut ciphertext_blob = DATA_BLOB {
72 cbData: 0,
73 pbData: ptr::null_mut(),
74 };
75
76 // SAFETY: the FFI call is safe because
77 // 1. plaintext_blob lives longer than the call.
78 // 2. ciphertext_blob lives longer than the call, and we later give ownership of the memory the
79 // kernel allocates to LocalAllocBuffer which guarantees it is freed.
80 let res = unsafe {
81 CryptProtectData(
82 &mut plaintext_blob as *mut _,
83 /* szDataDescr= */ ptr::null_mut(),
84 /* pOptionalEntropy= */ ptr::null_mut(),
85 /* pvReserved= */ ptr::null_mut(),
86 /* pPromptStruct */ ptr::null_mut(),
87 /* dwFlags */ 0,
88 &mut ciphertext_blob as *mut _,
89 )
90 };
91 if res == 0 {
92 syscall_bail!("CryptProtectData failed");
93 }
94
95 let ciphertext_len: usize = ciphertext_blob
96 .cbData
97 .try_into()
98 .context("resulting ciphertext had an invalid size")?;
99
100 // SAFETY: safe because ciphertext_blob refers to a valid buffer of the specified length. This
101 // is guaranteed because CryptProtectData returned success.
102 Ok(unsafe { LocalAllocBuffer::new(ciphertext_blob.pbData, ciphertext_len) })
103 }
104
105 /// # Summary
106 /// Wrapper around CryptProtectData that displays no UI.
crypt_unprotect_data(ciphertext: &mut [u8]) -> Result<LocalAllocBuffer>107 pub fn crypt_unprotect_data(ciphertext: &mut [u8]) -> Result<LocalAllocBuffer> {
108 let mut ciphertext_blob = DATA_BLOB {
109 cbData: ciphertext
110 .len()
111 .try_into()
112 .context("plaintext size won't fit in DWORD")?,
113 pbData: ciphertext.as_mut_ptr(),
114 };
115 let mut plaintext_blob = DATA_BLOB {
116 cbData: 0,
117 pbData: ptr::null_mut(),
118 };
119
120 // SAFETY: the FFI call is safe because
121 // 1. ciphertext_blob lives longer than the call.
122 // 2. plaintext_blob lives longer than the call, and we later give ownership of the memory the
123 // kernel allocates to LocalAllocBuffer which guarantees it is freed.
124 let res = unsafe {
125 CryptUnprotectData(
126 &mut ciphertext_blob as *mut _,
127 /* szDataDescr= */ ptr::null_mut(),
128 /* pOptionalEntropy= */ ptr::null_mut(),
129 /* pvReserved= */ ptr::null_mut(),
130 /* pPromptStruct */ ptr::null_mut(),
131 /* dwFlags */ 0,
132 &mut plaintext_blob as *mut _,
133 )
134 };
135 if res == 0 {
136 syscall_bail!("CryptUnprotectData failed");
137 }
138
139 let plaintext_len: usize = plaintext_blob
140 .cbData
141 .try_into()
142 .context("resulting plaintext had an invalid size")?;
143
144 // SAFETY: safe because plaintext_blob refers to a valid buffer of the specified length. This
145 // is guaranteed because CryptUnprotectData returned success.
146 Ok(unsafe { LocalAllocBuffer::new(plaintext_blob.pbData, plaintext_len) })
147 }
148
#[cfg(test)]
mod tests {
    use super::*;

    // Decrypts `ciphertext` with DPAPI and asserts that the recovered bytes are the UTF-8
    // encoding of `expected`.
    fn assert_decrypts_to(ciphertext: &mut LocalAllocBuffer, expected: &str) {
        let decrypted = crypt_unprotect_data(ciphertext.as_mut_slice()).unwrap();
        let recovered = std::str::from_utf8(decrypted.as_slice()).unwrap();
        assert_eq!(expected, recovered);
    }

    #[test]
    fn encrypt_empty_string_is_valid() {
        let original = "";
        let mut input = original.as_bytes().to_vec();

        let mut ciphertext = crypt_protect_data(&mut input).unwrap();
        assert_decrypts_to(&mut ciphertext, original);
    }

    #[test]
    fn encrypt_decrypt_plaintext_matches() {
        let original = "test plaintext";
        let mut input = original.as_bytes().to_vec();

        let mut ciphertext = crypt_protect_data(&mut input).unwrap();

        // Identical plaintext and ciphertext would mean nothing was encrypted.
        assert_ne!(original.as_bytes(), ciphertext.as_slice());

        // Round-tripping must give back exactly the original plaintext.
        assert_decrypts_to(&mut ciphertext, original);
    }
}
184