// Copyright 2022 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! VirtioDevice implementation for the VMM side of a vhost-user connection.

mod error;
mod fs;
mod handler;
mod sys;
mod worker;

use std::cell::RefCell;
use std::collections::BTreeMap;
use std::io::Read;
use std::io::Write;
use std::sync::Arc;

use anyhow::bail;
use anyhow::Context;
use base::error;
use base::trace;
use base::AsRawDescriptor;
use base::Event;
use base::RawDescriptor;
use base::WorkerThread;
use serde_json::Value;
use sync::Mutex;
use vm_memory::GuestMemory;
use vmm_vhost::message::VhostUserConfigFlags;
use vmm_vhost::message::VhostUserMigrationPhase;
use vmm_vhost::message::VhostUserProtocolFeatures;
use vmm_vhost::message::VhostUserTransferDirection;
use vmm_vhost::BackendClient;
use vmm_vhost::VhostUserMemoryRegionInfo;
use vmm_vhost::VringConfigData;
use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;

use crate::virtio::copy_config;
use crate::virtio::device_constants::VIRTIO_DEVICE_TYPE_SPECIFIC_FEATURES_MASK;
use crate::virtio::vhost_user_frontend::error::Error;
use crate::virtio::vhost_user_frontend::error::Result;
use crate::virtio::vhost_user_frontend::handler::BackendReqHandler;
use crate::virtio::vhost_user_frontend::handler::BackendReqHandlerImpl;
use crate::virtio::vhost_user_frontend::sys::create_backend_req_handler;
use crate::virtio::vhost_user_frontend::worker::Worker;
use crate::virtio::DeviceType;
use crate::virtio::Interrupt;
use crate::virtio::Queue;
use crate::virtio::SharedMemoryMapper;
use crate::virtio::SharedMemoryRegion;
use crate::virtio::VirtioDevice;
use crate::PciAddress;

pub struct VhostUserFrontend {
    device_type: DeviceType,
    worker_thread: Option<WorkerThread<Option<BackendReqHandler>>>,

    backend_client: Arc<Mutex<BackendClient>>,
    avail_features: u64,
    acked_features: u64,
    protocol_features: VhostUserProtocolFeatures,
    // `backend_req_handler` is only present if the backend supports BACKEND_REQ. `worker_thread`
    // takes ownership of `backend_req_handler` when it starts. The worker thread will always
    // return ownership of the handler when stopped.
    backend_req_handler: Option<BackendReqHandler>,
    // Shared memory region info. The outer Option is None until the backend has been queried;
    // the inner Option holds the backend's reply (None if it exposes no shared memory region).
    shmem_region: RefCell<Option<Option<SharedMemoryRegion>>>,

    queue_sizes: Vec<u16>,
    cfg: Option<Vec<u8>>,
    expose_shmem_descriptors_with_viommu: bool,
    pci_address: Option<PciAddress>,

    // Queues that have been sent to the backend. Always `Some` when active and not asleep. Saved
    // for use in `virtio_sleep`. Since the backend is managing them, the local state of the queue
    // is likely stale.
    sent_queues: Option<BTreeMap<usize, Queue>>,
}

// Returns the largest power of two that is less than or equal to `val`.
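// e.g. power_of_two_le(4) == Some(4), power_of_two_le(7) == Some(4), and
// power_of_two_le(0) == None.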
fn power_of_two_le(val: u16) -> Option<u16> {
    if val == 0 {
        None
    } else if val.is_power_of_two() {
        Some(val)
    } else {
        val.checked_next_power_of_two()
            .map(|next_pow_two| next_pow_two / 2)
    }
}

impl VhostUserFrontend {
    /// Create a new VirtioDevice for a vhost-user device frontend.
    ///
    /// # Arguments
    ///
    /// - `device_type`: virtio device type
    /// - `base_features`: base virtio device features (e.g. `VIRTIO_F_VERSION_1`)
    /// - `connection`: connection to the device backend
    /// - `max_queue_size`: maximum number of entries in each queue (default: [`Queue::MAX_SIZE`])
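    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest); `connection` and `base_features` are
    /// assumed to have been established elsewhere, with the backend implementing a virtio
    /// block device:
    ///
    /// ```ignore
    /// let frontend = VhostUserFrontend::new(
    ///     DeviceType::Block,
    ///     base_features,
    ///     connection, // vmm_vhost::Connection<vmm_vhost::FrontendReq>
    ///     None,       // max_queue_size: default to Queue::MAX_SIZE
    ///     None,       // pci_address: let the bus assign an address
    /// )?;
    /// ```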
    pub fn new(
        device_type: DeviceType,
        base_features: u64,
        connection: vmm_vhost::Connection<vmm_vhost::FrontendReq>,
        max_queue_size: Option<u16>,
        pci_address: Option<PciAddress>,
    ) -> Result<VhostUserFrontend> {
        VhostUserFrontend::new_internal(
            connection,
            device_type,
            max_queue_size,
            base_features,
            None, // cfg
            pci_address,
        )
    }

    /// Create a new VirtioDevice for a vhost-user device frontend.
    ///
    /// # Arguments
    ///
    /// - `connection`: connection to the device backend
    /// - `device_type`: virtio device type
    /// - `max_queue_size`: maximum number of entries in each queue (default: [`Queue::MAX_SIZE`])
    /// - `base_features`: base virtio device features (e.g. `VIRTIO_F_VERSION_1`)
    /// - `cfg`: bytes to return for the virtio configuration space (queried from device if not
    ///   specified)
    pub(crate) fn new_internal(
        connection: vmm_vhost::Connection<vmm_vhost::FrontendReq>,
        device_type: DeviceType,
        max_queue_size: Option<u16>,
        mut base_features: u64,
        cfg: Option<&[u8]>,
        pci_address: Option<PciAddress>,
    ) -> Result<VhostUserFrontend> {
        // Don't allow packed queues even if requested. We don't handle them properly yet at the
        // protocol layer.
        // TODO: b/331466964 - Remove once packed queue support is added to BackendClient.
        if base_features & (1 << virtio_sys::virtio_config::VIRTIO_F_RING_PACKED) != 0 {
            base_features &= !(1 << virtio_sys::virtio_config::VIRTIO_F_RING_PACKED);
            base::warn!(
                "VIRTIO_F_RING_PACKED requested, but not yet supported by vhost-user frontend. \
                Automatically disabled."
            );
        }

        #[cfg(windows)]
        let backend_pid = connection.target_pid();

        let mut backend_client = BackendClient::new(connection);

        backend_client.set_owner().map_err(Error::SetOwner)?;

        let allow_features = VIRTIO_DEVICE_TYPE_SPECIFIC_FEATURES_MASK
            | base_features
            | 1 << VHOST_USER_F_PROTOCOL_FEATURES;
        let avail_features =
            allow_features & backend_client.get_features().map_err(Error::GetFeatures)?;
        let mut acked_features = 0;

        let mut allow_protocol_features = VhostUserProtocolFeatures::CONFIG
            | VhostUserProtocolFeatures::MQ
            | VhostUserProtocolFeatures::BACKEND_REQ
            | VhostUserProtocolFeatures::DEVICE_STATE;

        // HACK: the crosvm vhost-user GPU backend supports the non-standard
        // VHOST_USER_PROTOCOL_FEATURE_SHARED_MEMORY_REGIONS. This should either be standardized
        // (and enabled for all device types) or removed.
        let expose_shmem_descriptors_with_viommu = if device_type == DeviceType::Gpu {
            allow_protocol_features |= VhostUserProtocolFeatures::SHARED_MEMORY_REGIONS;
            true
        } else {
            false
        };

        let mut protocol_features = VhostUserProtocolFeatures::empty();
        if avail_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0 {
            // The vhost-user backend supports VHOST_USER_F_PROTOCOL_FEATURES; enable it.
            backend_client
                .set_features(1 << VHOST_USER_F_PROTOCOL_FEATURES)
                .map_err(Error::SetFeatures)?;
            acked_features |= 1 << VHOST_USER_F_PROTOCOL_FEATURES;

            let avail_protocol_features = backend_client
                .get_protocol_features()
                .map_err(Error::GetProtocolFeatures)?;
            protocol_features = allow_protocol_features & avail_protocol_features;
            backend_client
                .set_protocol_features(protocol_features)
                .map_err(Error::SetProtocolFeatures)?;
        }

        // Set up the backend request handler if the `VhostUserProtocolFeatures::BACKEND_REQ`
        // protocol feature was negotiated.
        let backend_req_handler =
            if protocol_features.contains(VhostUserProtocolFeatures::BACKEND_REQ) {
                let (handler, tx_fd) = create_backend_req_handler(
                    BackendReqHandlerImpl::new(),
                    #[cfg(windows)]
                    backend_pid,
                )?;
                backend_client
                    .set_backend_req_fd(&tx_fd)
                    .map_err(Error::SetDeviceRequestChannel)?;
                Some(handler)
            } else {
                None
            };

        // If the device supports VHOST_USER_PROTOCOL_F_MQ, use VHOST_USER_GET_QUEUE_NUM to
        // determine the number of queues supported. Otherwise, use the minimum number of queues
        // required by the spec for this device type.
        let num_queues = if protocol_features.contains(VhostUserProtocolFeatures::MQ) {
            trace!("backend supports VHOST_USER_PROTOCOL_F_MQ");
            let num_queues = backend_client.get_queue_num().map_err(Error::GetQueueNum)?;
            trace!("VHOST_USER_GET_QUEUE_NUM returned {num_queues}");
            num_queues as usize
        } else {
            trace!("backend does not support VHOST_USER_PROTOCOL_F_MQ");
            device_type.min_queues()
        };

        // Clamp the maximum queue size to the largest power of 2 <= max_queue_size.
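        // e.g. a requested max_queue_size of 300 is clamped down to 256.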
        let max_queue_size = max_queue_size
            .and_then(power_of_two_le)
            .unwrap_or(Queue::MAX_SIZE);

        trace!(
            "vhost-user {device_type} frontend with {num_queues} queues x {max_queue_size} entries\
            {}",
            if let Some(pci_address) = pci_address {
                format!(" pci-address {pci_address}")
            } else {
                "".to_string()
            }
        );

        let queue_sizes = vec![max_queue_size; num_queues];

        Ok(VhostUserFrontend {
            device_type,
            worker_thread: None,
            backend_client: Arc::new(Mutex::new(backend_client)),
            avail_features,
            acked_features,
            protocol_features,
            backend_req_handler,
            shmem_region: RefCell::new(None),
            queue_sizes,
            cfg: cfg.map(|cfg| cfg.to_vec()),
            expose_shmem_descriptors_with_viommu,
            pci_address,
            sent_queues: None,
        })
    }

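    /// Sends the guest memory layout to the backend (`VHOST_USER_SET_MEM_TABLE`) so that it can
    /// translate the guest physical addresses used by the vrings.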
    fn set_mem_table(&mut self, mem: &GuestMemory) -> Result<()> {
        let regions: Vec<_> = mem
            .regions()
            .map(|region| VhostUserMemoryRegionInfo {
                guest_phys_addr: region.guest_addr.0,
                memory_size: region.size as u64,
                userspace_addr: region.host_addr as u64,
                mmap_offset: region.shm_offset,
                mmap_handle: region.shm.as_raw_descriptor(),
            })
            .collect();

        self.backend_client
            .lock()
            .set_mem_table(regions.as_slice())
            .map_err(Error::SetMemTable)?;

        Ok(())
    }

    /// Activates a vring for the given `queue`.
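    ///
    /// Sends the ring size, descriptor table/available ring/used ring addresses, base index, and
    /// call/kick events to the backend, then enables the vring (when protocol features were
    /// negotiated).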
    fn activate_vring(
        &mut self,
        mem: &GuestMemory,
        queue_index: usize,
        queue: &Queue,
        irqfd: &Event,
    ) -> Result<()> {
        let backend_client = self.backend_client.lock();
        backend_client
            .set_vring_num(queue_index, queue.size())
            .map_err(Error::SetVringNum)?;

        let config_data = VringConfigData {
            queue_size: queue.size(),
            flags: 0u32,
            desc_table_addr: mem
                .get_host_address(queue.desc_table())
                .map_err(Error::GetHostAddress)? as u64,
            used_ring_addr: mem
                .get_host_address(queue.used_ring())
                .map_err(Error::GetHostAddress)? as u64,
            avail_ring_addr: mem
                .get_host_address(queue.avail_ring())
                .map_err(Error::GetHostAddress)? as u64,
            log_addr: None,
        };
        backend_client
            .set_vring_addr(queue_index, &config_data)
            .map_err(Error::SetVringAddr)?;

        backend_client
            .set_vring_base(queue_index, queue.next_avail_to_process())
            .map_err(Error::SetVringBase)?;

        backend_client
            .set_vring_call(queue_index, irqfd)
            .map_err(Error::SetVringCall)?;
        backend_client
            .set_vring_kick(queue_index, queue.event())
            .map_err(Error::SetVringKick)?;

        // Per protocol documentation, `VHOST_USER_SET_VRING_ENABLE` should be sent only when
        // `VHOST_USER_F_PROTOCOL_FEATURES` has been negotiated.
        if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0 {
            backend_client
                .set_vring_enable(queue_index, true)
                .map_err(Error::SetVringEnable)?;
        }

        Ok(())
    }

    /// Stops the vring for the given `queue`, returning its base index.
    fn deactivate_vring(&self, queue_index: usize) -> Result<u16> {
        let backend_client = self.backend_client.lock();

        if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0 {
            backend_client
                .set_vring_enable(queue_index, false)
                .map_err(Error::SetVringEnable)?;
        }

        let vring_base = backend_client
            .get_vring_base(queue_index)
            .map_err(Error::GetVringBase)?;

        vring_base
            .try_into()
            .map_err(|_| Error::VringBaseTooBig(vring_base))
    }

    /// Helper to start the worker thread that handles interrupts and requests from the device
    /// process.
    fn start_worker(&mut self, interrupt: Interrupt, non_msix_evt: Event) {
        assert!(
            self.worker_thread.is_none(),
            "BUG: attempted to start worker twice"
        );

        let label = self.debug_label();

        let mut backend_req_handler = self.backend_req_handler.take();
        if let Some(handler) = &mut backend_req_handler {
            // Pass the interrupt to the backend request handler so it can signal the guest on
            // the backend's behalf.
            handler.frontend_mut().set_interrupt(interrupt.clone());
        }

        let backend_client = self.backend_client.clone();

        self.worker_thread = Some(WorkerThread::start(label.clone(), move |kill_evt| {
            let mut worker = Worker {
                kill_evt,
                non_msix_evt,
                backend_req_handler,
                backend_client,
            };
            worker
                .run(interrupt)
                .with_context(|| format!("{label}: vhost_user_frontend worker failed"))
                .unwrap();
            worker.backend_req_handler
        }));
    }
}

impl VirtioDevice for VhostUserFrontend {
    // Override the default debug label to differentiate vhost-user devices from virtio.
    fn debug_label(&self) -> String {
        format!("vu-{}", self.device_type())
    }

    fn keep_rds(&self) -> Vec<RawDescriptor> {
        Vec::new()
    }

    fn device_type(&self) -> DeviceType {
        self.device_type
    }

    fn queue_max_sizes(&self) -> &[u16] {
        &self.queue_sizes
    }

    fn features(&self) -> u64 {
        self.avail_features
    }

    fn ack_features(&mut self, features: u64) {
        let features = (features & self.avail_features) | self.acked_features;
        if let Err(e) = self
            .backend_client
            .lock()
            .set_features(features)
            .map_err(Error::SetFeatures)
        {
            error!("failed to enable features 0x{:x}: {}", features, e);
            return;
        }
        self.acked_features = features;
    }

    fn read_config(&self, offset: u64, data: &mut [u8]) {
        if let Some(cfg) = &self.cfg {
            copy_config(data, 0, cfg, offset);
            return;
        }

        let Ok(offset) = offset.try_into() else {
            error!("failed to read config: invalid config offset: {offset}");
            return;
        };
        let Ok(data_len) = data.len().try_into() else {
            error!(
                "failed to read config: invalid config length: {}",
                data.len()
            );
            return;
        };
        let (_, config) = match self.backend_client.lock().get_config(
            offset,
            data_len,
            VhostUserConfigFlags::WRITABLE,
            data,
        ) {
            Ok(x) => x,
            Err(e) => {
                error!("failed to read config: {}", Error::GetConfig(e));
                return;
            }
        };
        data.copy_from_slice(&config);
    }

    fn write_config(&mut self, offset: u64, data: &[u8]) {
        let Ok(offset) = offset.try_into() else {
            error!("failed to write config: invalid config offset: {offset}");
            return;
        };
        if let Err(e) = self
            .backend_client
            .lock()
            .set_config(offset, VhostUserConfigFlags::empty(), data)
            .map_err(Error::SetConfig)
        {
            error!("failed to write config: {}", e);
        }
    }

    fn activate(
        &mut self,
        mem: GuestMemory,
        interrupt: Interrupt,
        queues: BTreeMap<usize, Queue>,
    ) -> anyhow::Result<()> {
        self.set_mem_table(&mem)?;

        let msix_config_opt = interrupt
            .get_msix_config()
            .as_ref()
            .ok_or(Error::MsixConfigUnavailable)?;
        let msix_config = msix_config_opt.lock();

        let non_msix_evt = Event::new().map_err(Error::CreateEvent)?;
        for (&queue_index, queue) in queues.iter() {
            let irqfd = msix_config
                .get_irqfd(queue.vector() as usize)
                .unwrap_or(&non_msix_evt);
            self.activate_vring(&mem, queue_index, queue, irqfd)?;
        }

        self.sent_queues = Some(queues);

        drop(msix_config);

        self.start_worker(interrupt, non_msix_evt);
        Ok(())
    }

    fn reset(&mut self) -> anyhow::Result<()> {
        if let Some(sent_queues) = self.sent_queues.take() {
            for queue_index in sent_queues.into_keys() {
                let _vring_base = self
                    .deactivate_vring(queue_index)
                    .context("deactivate_vring failed during reset")?;
            }
        }

        if let Some(w) = self.worker_thread.take() {
            self.backend_req_handler = w.stop();
        }

        Ok(())
    }

    fn pci_address(&self) -> Option<PciAddress> {
        self.pci_address
    }

    fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
        if !self
            .protocol_features
            .contains(VhostUserProtocolFeatures::SHARED_MEMORY_REGIONS)
        {
            return None;
        }
        if let Some(r) = self.shmem_region.borrow().as_ref() {
            return r.clone();
        }
        let regions = match self
            .backend_client
            .lock()
            .get_shared_memory_regions()
            .map_err(Error::ShmemRegions)
        {
            Ok(x) => x,
            Err(e) => {
                error!("failed to get shared memory regions: {}", e);
                return None;
            }
        };
        let region = match regions.len() {
            0 => None,
            1 => Some(SharedMemoryRegion {
                id: regions[0].id,
                length: regions[0].length,
            }),
            n => {
                error!(
                    "failed to get shared memory regions: {}",
                    Error::TooManyShmemRegions(n)
                );
                return None;
            }
        };

        *self.shmem_region.borrow_mut() = Some(region.clone());
        region
    }

    fn set_shared_memory_mapper(&mut self, mapper: Box<dyn SharedMemoryMapper>) {
        // Log an error and bail out if the backend request handler is not available. This
        // indicates that `VhostUserProtocolFeatures::BACKEND_REQ` was not negotiated.
        let Some(backend_req_handler) = self.backend_req_handler.as_mut() else {
            error!(
                "Error setting shared memory mapper {}",
                Error::ProtocolFeatureNotNegoiated(VhostUserProtocolFeatures::BACKEND_REQ)
            );
            return;
        };

        // The virtio framework will only call this if get_shared_memory_region returned a region.
        let shmid = self
            .shmem_region
            .borrow()
            .clone()
            .flatten()
            .expect("missing shmid")
            .id;

        backend_req_handler
            .frontend_mut()
            .set_shared_mapper_state(mapper, shmid);
    }

    fn expose_shmem_descriptors_with_viommu(&self) -> bool {
        self.expose_shmem_descriptors_with_viommu
    }

    fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
        let Some(mut queues) = self.sent_queues.take() else {
            return Ok(None);
        };

        for (&queue_index, queue) in queues.iter_mut() {
            let vring_base = self
                .deactivate_vring(queue_index)
                .context("deactivate_vring failed during sleep")?;
            queue.vhost_user_reclaim(vring_base);
        }

        if let Some(w) = self.worker_thread.take() {
            self.backend_req_handler = w.stop();
        }

        Ok(Some(queues))
    }

    fn virtio_wake(
        &mut self,
        // Vhost-user doesn't need to pass queues_state back to the device process, since it
        // already has them.
        queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
    ) -> anyhow::Result<()> {
        if let Some((mem, interrupt, queues)) = queues_state {
            self.activate(mem, interrupt, queues)?;
        }
        Ok(())
    }

    fn virtio_snapshot(&mut self) -> anyhow::Result<Value> {
        if !self
            .protocol_features
            .contains(VhostUserProtocolFeatures::DEVICE_STATE)
        {
            bail!("snapshot requires VHOST_USER_PROTOCOL_F_DEVICE_STATE");
        }
        let backend_client = self.backend_client.lock();
        // Send the backend an FD to write the device state to. If it gives us an FD back, then
        // we need to read from that instead.
        let (mut r, w) = new_pipe_pair()?;
        let backend_r = backend_client
            .set_device_state_fd(
                VhostUserTransferDirection::Save,
                VhostUserMigrationPhase::Stopped,
                &w,
            )
            .context("failed to negotiate device state fd")?;
        // EOF signals end of the device state bytes, so it is important to close our copy of
        // the write FD before we start reading.
        std::mem::drop(w);
        // Read the device state.
        let mut snapshot_bytes = Vec::new();
        if let Some(mut backend_r) = backend_r {
            backend_r.read_to_end(&mut snapshot_bytes)
        } else {
            r.read_to_end(&mut snapshot_bytes)
        }
        .context("failed to read device state")?;
        // Call `check_device_state` to ensure the data transfer was successful.
        backend_client
            .check_device_state()
            .context("failed to transfer device state")?;
        Ok(serde_json::to_value(snapshot_bytes).map_err(Error::SliceToSerdeValue)?)
    }

    fn virtio_restore(&mut self, data: Value) -> anyhow::Result<()> {
        if !self
            .protocol_features
            .contains(VhostUserProtocolFeatures::DEVICE_STATE)
        {
            bail!("restore requires VHOST_USER_PROTOCOL_F_DEVICE_STATE");
        }

        let backend_client = self.backend_client.lock();
        let data_bytes: Vec<u8> = serde_json::from_value(data).map_err(Error::SerdeValueToSlice)?;
        // Send the backend an FD to read the device state from. If it gives us an FD back,
        // then we need to write to that instead.
        let (r, w) = new_pipe_pair()?;
        let backend_w = backend_client
            .set_device_state_fd(
                VhostUserTransferDirection::Load,
                VhostUserMigrationPhase::Stopped,
                &r,
            )
            .context("failed to negotiate device state fd")?;
        // Write the device state.
        {
            // EOF signals the end of the device state bytes, so we need to ensure the write
            // objects are dropped before the `check_device_state` call. Done here by moving
            // them into this scope.
            let backend_w = backend_w;
            let mut w = w;
            if let Some(mut backend_w) = backend_w {
                backend_w.write_all(data_bytes.as_slice())
            } else {
                w.write_all(data_bytes.as_slice())
            }
            .context("failed to write device state")?;
        }
        // Call `check_device_state` to ensure the data transfer was successful.
        backend_client
            .check_device_state()
            .context("failed to transfer device state")?;
        Ok(())
    }
}

#[cfg(unix)]
fn new_pipe_pair() -> anyhow::Result<(impl AsRawDescriptor + Read, impl AsRawDescriptor + Write)> {
    base::pipe().context("failed to create pipe")
}

#[cfg(windows)]
fn new_pipe_pair() -> anyhow::Result<(impl AsRawDescriptor + Read, impl AsRawDescriptor + Write)> {
    base::named_pipes::pair(
        &base::named_pipes::FramingMode::Byte,
        &base::named_pipes::BlockingMode::Wait,
        /* timeout= */ 0,
    )
    .context("failed to create named pipes")
}
709