xref: /aosp_15_r20/external/crosvm/base/src/descriptor_reflection.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2020 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 //! Provides infrastructure for de/serializing descriptors embedded in Rust data structures.
6 //!
7 //! # Example
8 //!
9 //! ```
10 //! use serde_json::to_string;
11 //! use base::{
12 //!     FileSerdeWrapper, FromRawDescriptor, SafeDescriptor, SerializeDescriptors,
13 //!     deserialize_with_descriptors,
14 //! };
15 //! use tempfile::tempfile;
16 //!
17 //! let tmp_f = tempfile().unwrap();
18 //!
19 //! // Uses a simple wrapper to serialize a File because we can't implement Serialize for File.
20 //! let data = FileSerdeWrapper(tmp_f);
21 //!
22 //! // Wraps Serialize types to collect side channel descriptors as Serialize is called.
23 //! let data_wrapper = SerializeDescriptors::new(&data);
24 //!
//! // Use the wrapper with any serializer to serialize data as normal, grabbing descriptors
//! // as the data structures are serialized by the serializer.
27 //! let out_json = serde_json::to_string(&data_wrapper).expect("failed to serialize");
28 //!
29 //! // If data_wrapper contains any side channel descriptor refs
30 //! // (it contains tmp_f in this case), we can retrieve the actual descriptors
31 //! // from the side channel using into_descriptors().
32 //! let out_descriptors = data_wrapper.into_descriptors();
33 //!
34 //! // When sending out_json over some transport, also send out_descriptors.
35 //!
36 //! // For this example, we aren't really transporting data across the process, but we do need to
37 //! // convert the descriptor type.
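//! let safe_descriptors = out_descriptors.iter().map(|&v| {
//!     // SAFETY: `v` is still open because `data` is forgotten below instead of dropped, so
//!     // ownership of each descriptor passes to exactly one `SafeDescriptor`.
//!     unsafe { SafeDescriptor::from_raw_descriptor(v) }
//! });
//! std::mem::forget(data); // Prevent double drop of tmp_f.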
42 //!
//! // The deserialize_with_descriptors function is used to give the descriptor deserializers
//! // access to side channel descriptors.
45 //! let res: FileSerdeWrapper =
46 //!     deserialize_with_descriptors(|| serde_json::from_str(&out_json), safe_descriptors)
47 //!        .expect("failed to deserialize");
48 //! ```
49 
50 use std::cell::Cell;
51 use std::cell::RefCell;
52 use std::convert::TryInto;
53 use std::fmt;
54 use std::fs::File;
55 use std::ops::Deref;
56 use std::ops::DerefMut;
57 use std::panic::catch_unwind;
58 use std::panic::resume_unwind;
59 use std::panic::AssertUnwindSafe;
60 
61 use serde::de;
62 use serde::de::Error;
63 use serde::de::Visitor;
64 use serde::ser;
65 use serde::Deserialize;
66 use serde::Deserializer;
67 use serde::Serialize;
68 use serde::Serializer;
69 
70 use super::RawDescriptor;
71 use crate::descriptor::SafeDescriptor;
72 
73 thread_local! {
74     static DESCRIPTOR_DST: RefCell<Option<Vec<RawDescriptor>>> = Default::default();
75 }
76 
77 /// Initializes the thread local storage for descriptor serialization. Fails if it was already
78 /// initialized without an intervening `take_descriptor_dst` on this thread.
init_descriptor_dst() -> Result<(), &'static str>79 fn init_descriptor_dst() -> Result<(), &'static str> {
80     DESCRIPTOR_DST.with(|d| {
81         let mut descriptors = d.borrow_mut();
82         if descriptors.is_some() {
83             return Err(
84                 "attempt to initialize descriptor destination that was already initialized",
85             );
86         }
87         *descriptors = Some(Default::default());
88         Ok(())
89     })
90 }
91 
92 /// Takes the thread local storage for descriptor serialization. Fails if there wasn't a prior call
93 /// to `init_descriptor_dst` on this thread.
take_descriptor_dst() -> Result<Vec<RawDescriptor>, &'static str>94 fn take_descriptor_dst() -> Result<Vec<RawDescriptor>, &'static str> {
95     match DESCRIPTOR_DST.with(|d| d.replace(None)) {
96         Some(d) => Ok(d),
97         None => Err("attempt to take descriptor destination before it was initialized"),
98     }
99 }
100 
101 /// Pushes a descriptor on the thread local destination of descriptors, returning the index in which
102 /// the descriptor was pushed.
103 //
104 /// Returns Err if the thread local destination was not already initialized.
push_descriptor(rd: RawDescriptor) -> Result<usize, &'static str>105 fn push_descriptor(rd: RawDescriptor) -> Result<usize, &'static str> {
106     DESCRIPTOR_DST.with(|d| {
107         d.borrow_mut()
108             .as_mut()
109             .ok_or("attempt to serialize descriptor without descriptor destination")
110             .map(|descriptors| {
111                 let index = descriptors.len();
112                 descriptors.push(rd);
113                 index
114             })
115     })
116 }
117 
118 /// Serializes a descriptor for later retrieval in a parent `SerializeDescriptors` struct.
119 ///
120 /// If there is no parent `SerializeDescriptors` being serialized, this will return an error.
121 ///
122 /// For convenience, it is recommended to use the `with_raw_descriptor` module in a `#[serde(with =
123 /// "...")]` attribute which will make use of this function.
serialize_descriptor<S: Serializer>( rd: &RawDescriptor, se: S, ) -> std::result::Result<S::Ok, S::Error>124 pub fn serialize_descriptor<S: Serializer>(
125     rd: &RawDescriptor,
126     se: S,
127 ) -> std::result::Result<S::Ok, S::Error> {
128     let index = push_descriptor(*rd).map_err(ser::Error::custom)?;
129     se.serialize_u32(
130         index
131             .try_into()
132             .map_err(|_| ser::Error::custom("attempt to serialize too many descriptors at once"))?,
133     )
134 }
135 
136 /// Wrapper for a `Serialize` value which will capture any descriptors exported by the value when
137 /// given to an ordinary `Serializer`.
138 ///
139 /// This is the corresponding type to use for serialization before using
140 /// `deserialize_with_descriptors`.
141 ///
142 /// # Examples
143 ///
144 /// ```
145 /// use serde_json::to_string;
146 /// use base::{FileSerdeWrapper, SerializeDescriptors};
147 /// use tempfile::tempfile;
148 ///
149 /// let tmp_f = tempfile().unwrap();
150 /// let data = FileSerdeWrapper(tmp_f);
151 /// let data_wrapper = SerializeDescriptors::new(&data);
152 ///
153 /// // Serializes `v` as normal...
154 /// let out_json = serde_json::to_string(&data_wrapper).expect("failed to serialize");
155 /// // If `serialize_descriptor` was called, we can capture the descriptors from here.
156 /// let out_descriptors = data_wrapper.into_descriptors();
157 /// ```
158 pub struct SerializeDescriptors<'a, T: Serialize>(&'a T, Cell<Vec<RawDescriptor>>);
159 
160 impl<'a, T: Serialize> SerializeDescriptors<'a, T> {
new(inner: &'a T) -> Self161     pub fn new(inner: &'a T) -> Self {
162         Self(inner, Default::default())
163     }
164 
into_descriptors(self) -> Vec<RawDescriptor>165     pub fn into_descriptors(self) -> Vec<RawDescriptor> {
166         self.1.into_inner()
167     }
168 }
169 
170 impl<'a, T: Serialize> Serialize for SerializeDescriptors<'a, T> {
serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: Serializer,171     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
172     where
173         S: Serializer,
174     {
175         init_descriptor_dst().map_err(ser::Error::custom)?;
176 
177         // catch_unwind is used to ensure that init_descriptor_dst is always balanced with a call to
178         // take_descriptor_dst afterwards.
179         let res = catch_unwind(AssertUnwindSafe(|| self.0.serialize(serializer)));
180         self.1.set(take_descriptor_dst().unwrap());
181         match res {
182             Ok(r) => r,
183             Err(e) => resume_unwind(e),
184         }
185     }
186 }
187 
188 thread_local! {
189     static DESCRIPTOR_SRC: RefCell<Option<Vec<Option<SafeDescriptor>>>> = Default::default();
190 }
191 
192 /// Sets the thread local storage of descriptors for deserialization. Fails if this was already
193 /// called without a call to `take_descriptor_src` on this thread.
194 ///
195 /// This is given as a collection of `Option` so that unused descriptors can be returned.
set_descriptor_src(descriptors: Vec<Option<SafeDescriptor>>) -> Result<(), &'static str>196 fn set_descriptor_src(descriptors: Vec<Option<SafeDescriptor>>) -> Result<(), &'static str> {
197     DESCRIPTOR_SRC.with(|d| {
198         let mut src = d.borrow_mut();
199         if src.is_some() {
200             return Err("attempt to set descriptor source that was already set");
201         }
202         *src = Some(descriptors);
203         Ok(())
204     })
205 }
206 
207 /// Takes the thread local storage of descriptors for deserialization. Fails if the storage was
208 /// already taken or never set with `set_descriptor_src`.
209 ///
210 /// If deserialization was done, the descriptors will mostly come back as `None` unless some of them
211 /// were unused.
take_descriptor_src() -> Result<Vec<Option<SafeDescriptor>>, &'static str>212 fn take_descriptor_src() -> Result<Vec<Option<SafeDescriptor>>, &'static str> {
213     DESCRIPTOR_SRC.with(|d| {
214         d.replace(None)
215             .ok_or("attempt to take descriptor source which was never set")
216     })
217 }
218 
219 /// Takes a descriptor at the given index from the thread local source of descriptors.
220 //
221 /// Returns None if the thread local source was not already initialized.
take_descriptor(index: usize) -> Result<SafeDescriptor, &'static str>222 fn take_descriptor(index: usize) -> Result<SafeDescriptor, &'static str> {
223     DESCRIPTOR_SRC.with(|d| {
224         d.borrow_mut()
225             .as_mut()
226             .ok_or("attempt to deserialize descriptor without descriptor source")?
227             .get_mut(index)
228             .ok_or("attempt to deserialize out of bounds descriptor")?
229             .take()
230             .ok_or("attempt to deserialize descriptor that was already taken")
231     })
232 }
233 
234 /// Deserializes a descriptor provided via `deserialize_with_descriptors`.
235 ///
236 /// If `deserialize_with_descriptors` is not in the call chain, this will return an error.
237 ///
238 /// For convenience, it is recommended to use the `with_raw_descriptor` module in a `#[serde(with =
239 /// "...")]` attribute which will make use of this function.
deserialize_descriptor<'de, D>(de: D) -> std::result::Result<SafeDescriptor, D::Error> where D: Deserializer<'de>,240 pub fn deserialize_descriptor<'de, D>(de: D) -> std::result::Result<SafeDescriptor, D::Error>
241 where
242     D: Deserializer<'de>,
243 {
244     struct DescriptorVisitor;
245 
246     impl<'de> Visitor<'de> for DescriptorVisitor {
247         type Value = u32;
248 
249         fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
250             formatter.write_str("an integer which fits into a u32")
251         }
252 
253         fn visit_u8<E: de::Error>(self, value: u8) -> Result<Self::Value, E> {
254             Ok(value as _)
255         }
256 
257         fn visit_u16<E: de::Error>(self, value: u16) -> Result<Self::Value, E> {
258             Ok(value as _)
259         }
260 
261         fn visit_u32<E: de::Error>(self, value: u32) -> Result<Self::Value, E> {
262             Ok(value)
263         }
264 
265         fn visit_u64<E: de::Error>(self, value: u64) -> Result<Self::Value, E> {
266             value.try_into().map_err(E::custom)
267         }
268 
269         fn visit_u128<E: de::Error>(self, value: u128) -> Result<Self::Value, E> {
270             value.try_into().map_err(E::custom)
271         }
272 
273         fn visit_i8<E: de::Error>(self, value: i8) -> Result<Self::Value, E> {
274             value.try_into().map_err(E::custom)
275         }
276 
277         fn visit_i16<E: de::Error>(self, value: i16) -> Result<Self::Value, E> {
278             value.try_into().map_err(E::custom)
279         }
280 
281         fn visit_i32<E: de::Error>(self, value: i32) -> Result<Self::Value, E> {
282             value.try_into().map_err(E::custom)
283         }
284 
285         fn visit_i64<E: de::Error>(self, value: i64) -> Result<Self::Value, E> {
286             value.try_into().map_err(E::custom)
287         }
288 
289         fn visit_i128<E: de::Error>(self, value: i128) -> Result<Self::Value, E> {
290             value.try_into().map_err(E::custom)
291         }
292     }
293 
294     let index = de.deserialize_u32(DescriptorVisitor)? as usize;
295     take_descriptor(index).map_err(D::Error::custom)
296 }
297 
298 /// Allows the use of any serde deserializer within a closure while providing access to the a set of
299 /// descriptors for use in `deserialize_descriptor`.
300 ///
301 /// This is the corresponding call to use deserialize after using `SerializeDescriptors`.
302 ///
303 /// If `deserialize_with_descriptors` is called anywhere within the given closure, it return an
304 /// error.
deserialize_with_descriptors<F, T, E>( f: F, descriptors: impl IntoIterator<Item = SafeDescriptor>, ) -> Result<T, E> where F: FnOnce() -> Result<T, E>, E: de::Error,305 pub fn deserialize_with_descriptors<F, T, E>(
306     f: F,
307     descriptors: impl IntoIterator<Item = SafeDescriptor>,
308 ) -> Result<T, E>
309 where
310     F: FnOnce() -> Result<T, E>,
311     E: de::Error,
312 {
313     let descriptor_src = descriptors.into_iter().map(Option::Some).collect();
314     set_descriptor_src(descriptor_src).map_err(E::custom)?;
315 
316     // catch_unwind is used to ensure that set_descriptor_src is always balanced with a call to
317     // take_descriptor_src afterwards.
318     let res = catch_unwind(AssertUnwindSafe(f));
319 
320     // unwrap is used because set_descriptor_src is always called before this, so it should never
321     // panic.
322     let empty_descriptors = take_descriptor_src().unwrap();
323 
324     // The deserializer should have consumed every descriptor.
325     debug_assert!(empty_descriptors.into_iter().all(|d| d.is_none()));
326 
327     match res {
328         Ok(r) => r,
329         Err(e) => resume_unwind(e),
330     }
331 }
332 
333 /// Module that exports `serialize`/`deserialize` functions for use with `#[serde(with = "...")]`
334 /// attribute. It only works with fields with `RawDescriptor` type.
335 ///
336 /// # Examples
337 ///
338 /// ```
339 /// use serde::{Deserialize, Serialize};
340 /// use base::RawDescriptor;
341 ///
342 /// #[derive(Serialize, Deserialize)]
343 /// struct RawContainer {
344 ///     #[serde(with = "base::with_raw_descriptor")]
345 ///     rd: RawDescriptor,
346 /// }
347 /// ```
348 pub mod with_raw_descriptor {
349     use serde::Deserializer;
350 
351     use super::super::RawDescriptor;
352     pub use super::serialize_descriptor as serialize;
353     use crate::descriptor::IntoRawDescriptor;
354 
deserialize<'de, D>(de: D) -> std::result::Result<RawDescriptor, D::Error> where D: Deserializer<'de>,355     pub fn deserialize<'de, D>(de: D) -> std::result::Result<RawDescriptor, D::Error>
356     where
357         D: Deserializer<'de>,
358     {
359         super::deserialize_descriptor(de).map(IntoRawDescriptor::into_raw_descriptor)
360     }
361 }
362 
363 /// Module that exports `serialize`/`deserialize` functions for use with `#[serde(with = "...")]`
364 /// attribute.
365 ///
366 /// # Examples
367 ///
368 /// ```
369 /// use std::fs::File;
370 /// use serde::{Deserialize, Serialize};
371 /// use base::RawDescriptor;
372 ///
373 /// #[derive(Serialize, Deserialize)]
374 /// struct FileContainer {
375 ///     #[serde(with = "base::with_as_descriptor")]
376 ///     file: File,
377 /// }
378 /// ```
379 pub mod with_as_descriptor {
380     use serde::Deserializer;
381     use serde::Serializer;
382 
383     use crate::descriptor::AsRawDescriptor;
384     use crate::descriptor::FromRawDescriptor;
385     use crate::descriptor::IntoRawDescriptor;
386 
serialize<S: Serializer>( rd: &dyn AsRawDescriptor, se: S, ) -> std::result::Result<S::Ok, S::Error>387     pub fn serialize<S: Serializer>(
388         rd: &dyn AsRawDescriptor,
389         se: S,
390     ) -> std::result::Result<S::Ok, S::Error> {
391         super::serialize_descriptor(&rd.as_raw_descriptor(), se)
392     }
393 
deserialize<'de, D, T>(de: D) -> std::result::Result<T, D::Error> where D: Deserializer<'de>, T: FromRawDescriptor,394     pub fn deserialize<'de, D, T>(de: D) -> std::result::Result<T, D::Error>
395     where
396         D: Deserializer<'de>,
397         T: FromRawDescriptor,
398     {
399         super::deserialize_descriptor(de)
400             .map(IntoRawDescriptor::into_raw_descriptor)
401             .map(|rd|
402                 // SAFETY: rd is expected to be valid for the duration of the call.
403                 unsafe { T::from_raw_descriptor(rd) })
404     }
405 }
406 
407 /// A simple wrapper around `File` that implements `Serialize`/`Deserialize`, which is useful when
408 /// the `#[serde(with = "with_as_descriptor")]` trait is infeasible, such as for a field with type
409 /// `Option<File>`.
410 #[derive(Serialize, Deserialize)]
411 #[serde(transparent)]
412 pub struct FileSerdeWrapper(#[serde(with = "with_as_descriptor")] pub File);
413 
414 impl fmt::Debug for FileSerdeWrapper {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result415     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
416         self.0.fmt(f)
417     }
418 }
419 
420 impl From<File> for FileSerdeWrapper {
from(file: File) -> Self421     fn from(file: File) -> Self {
422         FileSerdeWrapper(file)
423     }
424 }
425 
426 impl From<FileSerdeWrapper> for File {
from(f: FileSerdeWrapper) -> File427     fn from(f: FileSerdeWrapper) -> File {
428         f.0
429     }
430 }
431 
432 impl Deref for FileSerdeWrapper {
433     type Target = File;
deref(&self) -> &Self::Target434     fn deref(&self) -> &Self::Target {
435         &self.0
436     }
437 }
438 
439 impl DerefMut for FileSerdeWrapper {
deref_mut(&mut self) -> &mut Self::Target440     fn deref_mut(&mut self) -> &mut Self::Target {
441         &mut self.0
442     }
443 }
444 
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::fs::File;
    use std::mem::ManuallyDrop;

    use serde::de::DeserializeOwned;
    use serde::Deserialize;
    use serde::Serialize;
    use tempfile::tempfile;

    use super::super::deserialize_with_descriptors;
    use super::super::with_as_descriptor;
    use super::super::with_raw_descriptor;
    use super::super::AsRawDescriptor;
    use super::super::FileSerdeWrapper;
    use super::super::FromRawDescriptor;
    use super::super::RawDescriptor;
    use super::super::SafeDescriptor;
    use super::super::SerializeDescriptors;

    /// Serializes `value` to JSON through `SerializeDescriptors`. Returns the JSON text together
    /// with the descriptors captured on the side channel.
    fn to_json_with_descriptors<T: Serialize>(value: &T) -> (String, Vec<RawDescriptor>) {
        let wrapper = SerializeDescriptors::new(value);
        let json = serde_json::to_string(&wrapper).unwrap();
        (json, wrapper.into_descriptors())
    }

    /// Deserializes `json`, handing each raw descriptor in `descriptors` to the deserializer.
    fn from_json_with_descriptors<T: DeserializeOwned>(
        json: &str,
        descriptors: &[RawDescriptor],
    ) -> T {
        let owned = descriptors.iter().map(|&rd| {
            // SAFETY: `rd` is expected to be valid (or, for the fake descriptor in `raw`, never
            // closed), and each test keeps any original owner from closing it a second time.
            unsafe { SafeDescriptor::from_raw_descriptor(rd) }
        });
        deserialize_with_descriptors(|| serde_json::from_str(json), owned).unwrap()
    }

    #[test]
    fn raw() {
        #[derive(Serialize, Deserialize, PartialEq, Debug)]
        struct RawContainer {
            #[serde(with = "with_raw_descriptor")]
            rd: RawDescriptor,
        }

        // Chosen so it cannot collide with a real open descriptor, so no descriptor has to be
        // allocated for this test.
        let fake_rd = 5_123_457_i32;
        let expected = RawContainer {
            rd: fake_rd as RawDescriptor,
        };
        let (json, descriptors) = to_json_with_descriptors(&expected);
        let actual: RawContainer = from_json_with_descriptors(&json, &descriptors);
        assert_eq!(expected, actual);
    }

    #[test]
    fn file() {
        #[derive(Serialize, Deserialize)]
        struct FileContainer {
            #[serde(with = "with_as_descriptor")]
            file: File,
        }

        let original = FileContainer {
            file: tempfile().unwrap(),
        };
        let (json, descriptors) = to_json_with_descriptors(&original);
        // The restored container takes over the descriptor, so the original must never close it.
        let original = ManuallyDrop::new(original);
        let restored: FileContainer = from_json_with_descriptors(&json, &descriptors);
        assert_eq!(
            original.file.as_raw_descriptor(),
            restored.file.as_raw_descriptor()
        );
    }

    #[test]
    fn option() {
        #[derive(Serialize, Deserialize)]
        struct TestOption {
            a: Option<FileSerdeWrapper>,
            b: Option<FileSerdeWrapper>,
        }

        let original = TestOption {
            a: None,
            b: Some(tempfile().unwrap().into()),
        };
        let (json, descriptors) = to_json_with_descriptors(&original);
        // The restored value takes over the descriptor, so the original must never close it.
        let original = ManuallyDrop::new(original);
        let restored: TestOption = from_json_with_descriptors(&json, &descriptors);
        assert!(restored.a.is_none());
        assert!(restored.b.is_some());
        let original_rd = original.b.as_ref().unwrap().as_raw_descriptor();
        assert_eq!(original_rd, restored.b.unwrap().as_raw_descriptor());
    }

    #[test]
    fn map() {
        let mut original: HashMap<String, FileSerdeWrapper> = HashMap::new();
        for key in ["a", "b", "c"].iter() {
            original.insert(key.to_string(), tempfile().unwrap().into());
        }
        let (json, descriptors) = to_json_with_descriptors(&original);
        // Keep the files in `original` from closing, because `restored` now owns their
        // descriptors. The HashMap's own heap memory is still freed, so the leak sanitizer is not
        // triggered.
        let original: HashMap<_, _> = original
            .into_iter()
            .map(|(key, file)| (key, ManuallyDrop::new(file)))
            .collect();
        let restored: HashMap<String, FileSerdeWrapper> =
            from_json_with_descriptors(&json, &descriptors);

        assert_eq!(original.len(), restored.len());
        for (key, file) in &original {
            assert_eq!(
                restored.get(key).unwrap().as_raw_descriptor(),
                file.as_raw_descriptor()
            );
        }
    }
}
565