// Copyright 2022 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! VirtioDevice implementation for the VMM side of a vhost-user connection.

mod error;
mod fs;
mod handler;
mod sys;
mod worker;

use std::cell::RefCell;
use std::collections::BTreeMap;
use std::io::Read;
use std::io::Write;
use std::sync::Arc;

use anyhow::bail;
use anyhow::Context;
use base::error;
use base::trace;
use base::AsRawDescriptor;
use base::Event;
use base::RawDescriptor;
use base::WorkerThread;
use serde_json::Value;
use sync::Mutex;
use vm_memory::GuestMemory;
use vmm_vhost::message::VhostUserConfigFlags;
use vmm_vhost::message::VhostUserMigrationPhase;
use vmm_vhost::message::VhostUserProtocolFeatures;
use vmm_vhost::message::VhostUserTransferDirection;
use vmm_vhost::BackendClient;
use vmm_vhost::VhostUserMemoryRegionInfo;
use vmm_vhost::VringConfigData;
use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;

use crate::virtio::copy_config;
use crate::virtio::device_constants::VIRTIO_DEVICE_TYPE_SPECIFIC_FEATURES_MASK;
use crate::virtio::vhost_user_frontend::error::Error;
use crate::virtio::vhost_user_frontend::error::Result;
use crate::virtio::vhost_user_frontend::handler::BackendReqHandler;
use crate::virtio::vhost_user_frontend::handler::BackendReqHandlerImpl;
use crate::virtio::vhost_user_frontend::sys::create_backend_req_handler;
use crate::virtio::vhost_user_frontend::worker::Worker;
use crate::virtio::DeviceType;
use crate::virtio::Interrupt;
use crate::virtio::Queue;
use crate::virtio::SharedMemoryMapper;
use crate::virtio::SharedMemoryRegion;
use crate::virtio::VirtioDevice;
use crate::PciAddress;

pub struct VhostUserFrontend {
    device_type: DeviceType,
    worker_thread: Option<WorkerThread<Option<BackendReqHandler>>>,

    backend_client: Arc<Mutex<BackendClient>>,
    avail_features: u64,
    acked_features: u64,
    protocol_features: VhostUserProtocolFeatures,
    // `backend_req_handler` is only present if the backend supports BACKEND_REQ. `worker_thread`
    // takes ownership of `backend_req_handler` when it starts. The worker thread will always
    // return ownership of the handler when stopped.
    backend_req_handler: Option<BackendReqHandler>,
    // Shared memory region info. The outer `Option` records whether the backend has been queried
    // yet; the inner `Option` is the backend's reply (`None` if the device exposes no shared
    // memory region).
    shmem_region: RefCell<Option<Option<SharedMemoryRegion>>>,

    queue_sizes: Vec<u16>,
    cfg: Option<Vec<u8>>,
    expose_shmem_descriptors_with_viommu: bool,
    pci_address: Option<PciAddress>,

    // Queues that have been sent to the backend. Always `Some` when active and not asleep. Saved
    // for use in `virtio_sleep`. Since the backend is managing them, the local state of the queue
    // is likely stale.
    sent_queues: Option<BTreeMap<usize, Queue>>,
}

// Returns the largest power of two that is less than or equal to `val`. Returns `None` if `val`
// is 0, or if the next power of two overflows `u16` (callers fall back to a default in that
// case).
fn power_of_two_le(val: u16) -> Option<u16> {
    if val == 0 {
        None
    } else if val.is_power_of_two() {
        Some(val)
    } else {
        val.checked_next_power_of_two()
            .map(|next_pow_two| next_pow_two / 2)
    }
}

impl VhostUserFrontend {
    /// Create a new VirtioDevice for a vhost-user device frontend.
    ///
    /// # Arguments
    ///
    /// - `device_type`: virtio device type
    /// - `base_features`: base virtio device features (e.g. `VIRTIO_F_VERSION_1`)
    /// - `connection`: connection to the device backend
    /// - `max_queue_size`: maximum number of entries in each queue (default: [`Queue::MAX_SIZE`])
    /// - `pci_address`: preferred PCI address for the device, if any
    pub fn new(
        device_type: DeviceType,
        base_features: u64,
        connection: vmm_vhost::Connection<vmm_vhost::FrontendReq>,
        max_queue_size: Option<u16>,
        pci_address: Option<PciAddress>,
    ) -> Result<VhostUserFrontend> {
        VhostUserFrontend::new_internal(
            connection,
            device_type,
            max_queue_size,
            base_features,
            None, // cfg
            pci_address,
        )
    }

    /// Create a new VirtioDevice for a vhost-user device frontend.
    ///
    /// # Arguments
    ///
    /// - `connection`: connection to the device backend
    /// - `device_type`: virtio device type
    /// - `max_queue_size`: maximum number of entries in each queue (default: [`Queue::MAX_SIZE`])
    /// - `base_features`: base virtio device features (e.g. `VIRTIO_F_VERSION_1`)
    /// - `cfg`: bytes to return for the virtio configuration space (queried from device if not
    ///   specified)
    /// - `pci_address`: preferred PCI address for the device, if any
    pub(crate) fn new_internal(
        connection: vmm_vhost::Connection<vmm_vhost::FrontendReq>,
        device_type: DeviceType,
        max_queue_size: Option<u16>,
        mut base_features: u64,
        cfg: Option<&[u8]>,
        pci_address: Option<PciAddress>,
    ) -> Result<VhostUserFrontend> {
        // Don't allow packed queues even if requested. We don't handle them properly yet at the
        // protocol layer.
        // TODO: b/331466964 - Remove once packed queue support is added to BackendClient.
        if base_features & (1 << virtio_sys::virtio_config::VIRTIO_F_RING_PACKED) != 0 {
            base_features &= !(1 << virtio_sys::virtio_config::VIRTIO_F_RING_PACKED);
            base::warn!(
                "VIRTIO_F_RING_PACKED requested, but not yet supported by vhost-user frontend. \
                Automatically disabled."
            );
        }

        #[cfg(windows)]
        let backend_pid = connection.target_pid();

        let mut backend_client = BackendClient::new(connection);

        backend_client.set_owner().map_err(Error::SetOwner)?;

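        // Feature negotiation: offer the backend only the feature bits we are prepared to
        // handle, then accept the subset the backend actually supports.
        // `VHOST_USER_F_PROTOCOL_FEATURES` additionally gates the protocol-feature handshake
        // performed below.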
        let allow_features = VIRTIO_DEVICE_TYPE_SPECIFIC_FEATURES_MASK
            | base_features
            | 1 << VHOST_USER_F_PROTOCOL_FEATURES;
        let avail_features =
            allow_features & backend_client.get_features().map_err(Error::GetFeatures)?;
        let mut acked_features = 0;

        let mut allow_protocol_features = VhostUserProtocolFeatures::CONFIG
            | VhostUserProtocolFeatures::MQ
            | VhostUserProtocolFeatures::BACKEND_REQ
            | VhostUserProtocolFeatures::DEVICE_STATE;

        // HACK: the crosvm vhost-user GPU backend supports the non-standard
        // VHOST_USER_PROTOCOL_FEATURE_SHARED_MEMORY_REGIONS. This should either be standardized
        // (and enabled for all device types) or removed.
        let expose_shmem_descriptors_with_viommu = if device_type == DeviceType::Gpu {
            allow_protocol_features |= VhostUserProtocolFeatures::SHARED_MEMORY_REGIONS;
            true
        } else {
            false
        };

        let mut protocol_features = VhostUserProtocolFeatures::empty();
        if avail_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0 {
            // The vhost-user backend supports VHOST_USER_F_PROTOCOL_FEATURES; enable it.
            backend_client
                .set_features(1 << VHOST_USER_F_PROTOCOL_FEATURES)
                .map_err(Error::SetFeatures)?;
            acked_features |= 1 << VHOST_USER_F_PROTOCOL_FEATURES;

            let avail_protocol_features = backend_client
                .get_protocol_features()
                .map_err(Error::GetProtocolFeatures)?;
            protocol_features = allow_protocol_features & avail_protocol_features;
            backend_client
                .set_protocol_features(protocol_features)
                .map_err(Error::SetProtocolFeatures)?;
        }

        // Set up the backend request handler if the protocol feature
        // `VhostUserProtocolFeatures::BACKEND_REQ` was negotiated.
        let backend_req_handler =
            if protocol_features.contains(VhostUserProtocolFeatures::BACKEND_REQ) {
                let (handler, tx_fd) = create_backend_req_handler(
                    BackendReqHandlerImpl::new(),
                    #[cfg(windows)]
                    backend_pid,
                )?;
                backend_client
                    .set_backend_req_fd(&tx_fd)
                    .map_err(Error::SetDeviceRequestChannel)?;
                Some(handler)
            } else {
                None
            };

        // If the device supports VHOST_USER_PROTOCOL_F_MQ, use VHOST_USER_GET_QUEUE_NUM to
        // determine the number of queues supported. Otherwise, use the minimum number of queues
        // required by the spec for this device type.
        let num_queues = if protocol_features.contains(VhostUserProtocolFeatures::MQ) {
            trace!("backend supports VHOST_USER_PROTOCOL_F_MQ");
            let num_queues = backend_client.get_queue_num().map_err(Error::GetQueueNum)?;
            trace!("VHOST_USER_GET_QUEUE_NUM returned {num_queues}");
            num_queues as usize
        } else {
            trace!("backend does not support VHOST_USER_PROTOCOL_F_MQ");
            device_type.min_queues()
        };

        // Clamp the maximum queue size to the largest power of 2 <= max_queue_size.
        let max_queue_size = max_queue_size
            .and_then(power_of_two_le)
            .unwrap_or(Queue::MAX_SIZE);

        trace!(
            "vhost-user {device_type} frontend with {num_queues} queues x {max_queue_size} entries\
            {}",
            if let Some(pci_address) = pci_address {
                format!(" pci-address {pci_address}")
            } else {
                "".to_string()
            }
        );

        let queue_sizes = vec![max_queue_size; num_queues];

        Ok(VhostUserFrontend {
            device_type,
            worker_thread: None,
            backend_client: Arc::new(Mutex::new(backend_client)),
            avail_features,
            acked_features,
            protocol_features,
            backend_req_handler,
            shmem_region: RefCell::new(None),
            queue_sizes,
            cfg: cfg.map(|cfg| cfg.to_vec()),
            expose_shmem_descriptors_with_viommu,
            pci_address,
            sent_queues: None,
        })
    }

    fn set_mem_table(&mut self, mem: &GuestMemory) -> Result<()> {
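        // Describe every guest memory region to the backend, including the backing
        // shared-memory descriptor, so the backend process can mmap guest RAM directly.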
        let regions: Vec<_> = mem
            .regions()
            .map(|region| VhostUserMemoryRegionInfo {
                guest_phys_addr: region.guest_addr.0,
                memory_size: region.size as u64,
                userspace_addr: region.host_addr as u64,
                mmap_offset: region.shm_offset,
                mmap_handle: region.shm.as_raw_descriptor(),
            })
            .collect();

        self.backend_client
            .lock()
            .set_mem_table(regions.as_slice())
            .map_err(Error::SetMemTable)?;

        Ok(())
    }

    /// Activates a vring for the given `queue`.
    fn activate_vring(
        &mut self,
        mem: &GuestMemory,
        queue_index: usize,
        queue: &Queue,
        irqfd: &Event,
    ) -> Result<()> {
        let backend_client = self.backend_client.lock();
        backend_client
            .set_vring_num(queue_index, queue.size())
            .map_err(Error::SetVringNum)?;

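        // The ring addresses sent to the backend are VMM userspace addresses: guest physical
        // addresses are translated through the mapping shared earlier via `set_mem_table`.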
        let config_data = VringConfigData {
            queue_size: queue.size(),
            flags: 0u32,
            desc_table_addr: mem
                .get_host_address(queue.desc_table())
                .map_err(Error::GetHostAddress)? as u64,
            used_ring_addr: mem
                .get_host_address(queue.used_ring())
                .map_err(Error::GetHostAddress)? as u64,
            avail_ring_addr: mem
                .get_host_address(queue.avail_ring())
                .map_err(Error::GetHostAddress)? as u64,
            log_addr: None,
        };
        backend_client
            .set_vring_addr(queue_index, &config_data)
            .map_err(Error::SetVringAddr)?;

        backend_client
            .set_vring_base(queue_index, queue.next_avail_to_process())
            .map_err(Error::SetVringBase)?;

        backend_client
            .set_vring_call(queue_index, irqfd)
            .map_err(Error::SetVringCall)?;
        backend_client
            .set_vring_kick(queue_index, queue.event())
            .map_err(Error::SetVringKick)?;

        // Per protocol documentation, `VHOST_USER_SET_VRING_ENABLE` should be sent only when
        // `VHOST_USER_F_PROTOCOL_FEATURES` has been negotiated.
        if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0 {
            backend_client
                .set_vring_enable(queue_index, true)
                .map_err(Error::SetVringEnable)?;
        }

        Ok(())
    }

    /// Stops the vring for the given `queue`, returning its base index.
    fn deactivate_vring(&self, queue_index: usize) -> Result<u16> {
        let backend_client = self.backend_client.lock();

        if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0 {
            backend_client
                .set_vring_enable(queue_index, false)
                .map_err(Error::SetVringEnable)?;
        }

        let vring_base = backend_client
            .get_vring_base(queue_index)
            .map_err(Error::GetVringBase)?;

        vring_base
            .try_into()
            .map_err(|_| Error::VringBaseTooBig(vring_base))
    }

    /// Helper to start the worker thread that handles interrupts and requests from the device
    /// process.
    fn start_worker(&mut self, interrupt: Interrupt, non_msix_evt: Event) {
        assert!(
            self.worker_thread.is_none(),
            "BUG: attempted to start worker twice"
        );

        let label = self.debug_label();

        let mut backend_req_handler = self.backend_req_handler.take();
        if let Some(handler) = &mut backend_req_handler {
            // Tell the backend request handler where to deliver interrupts for requests
            // originating from the device process.
            handler.frontend_mut().set_interrupt(interrupt.clone());
        }

        let backend_client = self.backend_client.clone();

        self.worker_thread = Some(WorkerThread::start(label.clone(), move |kill_evt| {
            let mut worker = Worker {
                kill_evt,
                non_msix_evt,
                backend_req_handler,
                backend_client,
            };
            worker
                .run(interrupt)
                .with_context(|| format!("{label}: vhost_user_frontend worker failed"))
                .unwrap();
            worker.backend_req_handler
        }));
    }
}

impl VirtioDevice for VhostUserFrontend {
    // Override the default debug label to differentiate vhost-user devices from virtio.
    fn debug_label(&self) -> String {
        format!("vu-{}", self.device_type())
    }

    fn keep_rds(&self) -> Vec<RawDescriptor> {
        Vec::new()
    }

    fn device_type(&self) -> DeviceType {
        self.device_type
    }

    fn queue_max_sizes(&self) -> &[u16] {
        &self.queue_sizes
    }

    fn features(&self) -> u64 {
        self.avail_features
    }

    fn ack_features(&mut self, features: u64) {
        let features = (features & self.avail_features) | self.acked_features;
        if let Err(e) = self
            .backend_client
            .lock()
            .set_features(features)
            .map_err(Error::SetFeatures)
        {
            error!("failed to enable features 0x{:x}: {}", features, e);
            return;
        }
        self.acked_features = features;
    }

    fn read_config(&self, offset: u64, data: &mut [u8]) {
        if let Some(cfg) = &self.cfg {
            copy_config(data, 0, cfg, offset);
            return;
        }

        let Ok(offset) = offset.try_into() else {
            error!("failed to read config: invalid config offset is given: {offset}");
            return;
        };
        let Ok(data_len) = data.len().try_into() else {
            error!(
                "failed to read config: invalid config length is given: {}",
                data.len()
            );
            return;
        };
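        // No local copy of the config space is available, so forward the read to the backend
        // (VHOST_USER_GET_CONFIG).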
        let (_, config) = match self.backend_client.lock().get_config(
            offset,
            data_len,
            VhostUserConfigFlags::WRITABLE,
            data,
        ) {
            Ok(x) => x,
            Err(e) => {
                error!("failed to read config: {}", Error::GetConfig(e));
                return;
            }
        };
        data.copy_from_slice(&config);
    }

    fn write_config(&mut self, offset: u64, data: &[u8]) {
        let Ok(offset) = offset.try_into() else {
            error!("failed to write config: invalid config offset is given: {offset}");
            return;
        };
        if let Err(e) = self
            .backend_client
            .lock()
            .set_config(offset, VhostUserConfigFlags::empty(), data)
            .map_err(Error::SetConfig)
        {
            error!("failed to write config: {}", e);
        }
    }

    fn activate(
        &mut self,
        mem: GuestMemory,
        interrupt: Interrupt,
        queues: BTreeMap<usize, Queue>,
    ) -> anyhow::Result<()> {
        self.set_mem_table(&mem)?;

        let msix_config_opt = interrupt
            .get_msix_config()
            .as_ref()
            .ok_or(Error::MsixConfigUnavailable)?;
        let msix_config = msix_config_opt.lock();

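        // Queues whose MSI-X vector has no irqfd fall back to this shared event; the worker
        // thread forwards signals on it to the guest via the interrupt object.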
        let non_msix_evt = Event::new().map_err(Error::CreateEvent)?;
        for (&queue_index, queue) in queues.iter() {
            let irqfd = msix_config
                .get_irqfd(queue.vector() as usize)
                .unwrap_or(&non_msix_evt);
            self.activate_vring(&mem, queue_index, queue, irqfd)?;
        }

        self.sent_queues = Some(queues);

        drop(msix_config);

        self.start_worker(interrupt, non_msix_evt);
        Ok(())
    }

    fn reset(&mut self) -> anyhow::Result<()> {
        if let Some(sent_queues) = self.sent_queues.take() {
            for queue_index in sent_queues.into_keys() {
                let _vring_base = self
                    .deactivate_vring(queue_index)
                    .context("deactivate_vring failed during reset")?;
            }
        }

        if let Some(w) = self.worker_thread.take() {
            self.backend_req_handler = w.stop();
        }

        Ok(())
    }

    fn pci_address(&self) -> Option<PciAddress> {
        self.pci_address
    }

    fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
        if !self
            .protocol_features
            .contains(VhostUserProtocolFeatures::SHARED_MEMORY_REGIONS)
        {
            return None;
        }
        if let Some(r) = self.shmem_region.borrow().as_ref() {
            return r.clone();
        }
        let regions = match self
            .backend_client
            .lock()
            .get_shared_memory_regions()
            .map_err(Error::ShmemRegions)
        {
            Ok(x) => x,
            Err(e) => {
                error!("Failed to get shared memory regions {}", e);
                return None;
            }
        };
        let region = match regions.len() {
            0 => None,
            1 => Some(SharedMemoryRegion {
                id: regions[0].id,
                length: regions[0].length,
            }),
            n => {
                error!(
                    "Failed to get shared memory regions {}",
                    Error::TooManyShmemRegions(n)
                );
                return None;
            }
        };

        *self.shmem_region.borrow_mut() = Some(region.clone());
        region
    }

    fn set_shared_memory_mapper(&mut self, mapper: Box<dyn SharedMemoryMapper>) {
        // Log an error and bail out if the backend request handler is not available, which
        // indicates that `VhostUserProtocolFeatures::BACKEND_REQ` was not negotiated.
        let Some(backend_req_handler) = self.backend_req_handler.as_mut() else {
            error!(
                "Error setting shared memory mapper {}",
                Error::ProtocolFeatureNotNegoiated(VhostUserProtocolFeatures::BACKEND_REQ)
            );
            return;
        };

        // The virtio framework will only call this if get_shared_memory_region returned a
        // region.
        let shmid = self
            .shmem_region
            .borrow()
            .clone()
            .flatten()
            .expect("missing shmid")
            .id;

        backend_req_handler
            .frontend_mut()
            .set_shared_mapper_state(mapper, shmid);
    }

    fn expose_shmem_descriptors_with_viommu(&self) -> bool {
        self.expose_shmem_descriptors_with_viommu
    }

    fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
        let Some(mut queues) = self.sent_queues.take() else {
            return Ok(None);
        };

        for (&queue_index, queue) in queues.iter_mut() {
            let vring_base = self
                .deactivate_vring(queue_index)
                .context("deactivate_vring failed during sleep")?;
            queue.vhost_user_reclaim(vring_base);
        }

        if let Some(w) = self.worker_thread.take() {
            self.backend_req_handler = w.stop();
        }

        Ok(Some(queues))
    }

    fn virtio_wake(
        &mut self,
        // Vhost user doesn't need to pass queue_states back to the device process, since it will
        // already have it.
        queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
    ) -> anyhow::Result<()> {
        if let Some((mem, interrupt, queues)) = queues_state {
            self.activate(mem, interrupt, queues)?;
        }
        Ok(())
    }

    fn virtio_snapshot(&mut self) -> anyhow::Result<Value> {
        if !self
            .protocol_features
            .contains(VhostUserProtocolFeatures::DEVICE_STATE)
        {
            bail!("snapshot requires VHOST_USER_PROTOCOL_F_DEVICE_STATE");
        }
        let backend_client = self.backend_client.lock();
        // Send the backend an FD to write the device state to. If it gives us an FD back, then
        // we need to read from that instead.
        let (mut r, w) = new_pipe_pair()?;
        let backend_r = backend_client
            .set_device_state_fd(
                VhostUserTransferDirection::Save,
                VhostUserMigrationPhase::Stopped,
                &w,
            )
            .context("failed to negotiate device state fd")?;
        // EOF signals end of the device state bytes, so it is important to close our copy of
        // the write FD before we start reading.
        std::mem::drop(w);
        // Read the device state.
        let mut snapshot_bytes = Vec::new();
        if let Some(mut backend_r) = backend_r {
            backend_r.read_to_end(&mut snapshot_bytes)
        } else {
            r.read_to_end(&mut snapshot_bytes)
        }
        .context("failed to read device state")?;
        // Call `check_device_state` to ensure the data transfer was successful.
        backend_client
            .check_device_state()
            .context("failed to transfer device state")?;
        Ok(serde_json::to_value(snapshot_bytes).map_err(Error::SliceToSerdeValue)?)
    }

    fn virtio_restore(&mut self, data: Value) -> anyhow::Result<()> {
        if !self
            .protocol_features
            .contains(VhostUserProtocolFeatures::DEVICE_STATE)
        {
            bail!("restore requires VHOST_USER_PROTOCOL_F_DEVICE_STATE");
        }

        let backend_client = self.backend_client.lock();
        let data_bytes: Vec<u8> = serde_json::from_value(data).map_err(Error::SerdeValueToSlice)?;
        // Send the backend an FD to read the device state from. If it gives us an FD back,
        // then we need to write to that instead.
        let (r, w) = new_pipe_pair()?;
        let backend_w = backend_client
            .set_device_state_fd(
                VhostUserTransferDirection::Load,
                VhostUserMigrationPhase::Stopped,
                &r,
            )
            .context("failed to negotiate device state fd")?;
        // Write the device state.
        {
            // EOF signals the end of the device state bytes, so we need to ensure the write
            // objects are dropped before the `check_device_state` call. Done here by moving
            // them into this scope.
            let backend_w = backend_w;
            let mut w = w;
            if let Some(mut backend_w) = backend_w {
                backend_w.write_all(data_bytes.as_slice())
            } else {
                w.write_all(data_bytes.as_slice())
            }
            .context("failed to write device state")?;
        }
        // Call `check_device_state` to ensure the data transfer was successful.
        backend_client
            .check_device_state()
            .context("failed to transfer device state")?;
        Ok(())
    }
}

#[cfg(unix)]
fn new_pipe_pair() -> anyhow::Result<(impl AsRawDescriptor + Read, impl AsRawDescriptor + Write)> {
    base::pipe().context("failed to create pipe")
}

#[cfg(windows)]
fn new_pipe_pair() -> anyhow::Result<(impl AsRawDescriptor + Read, impl AsRawDescriptor + Write)> {
    base::named_pipes::pair(
        &base::named_pipes::FramingMode::Byte,
        &base::named_pipes::BlockingMode::Wait,
        /* timeout= */ 0,
    )
    .context("failed to create named pipes")
}
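
// A minimal sanity check for `power_of_two_le` (added for illustration); the expected values
// follow directly from its definition, including the `u16` overflow edge case for
// non-powers-of-two above 0x8000.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn power_of_two_le_examples() {
        // Zero has no power of two less than or equal to it.
        assert_eq!(power_of_two_le(0), None);
        // Powers of two are returned unchanged.
        assert_eq!(power_of_two_le(1), Some(1));
        assert_eq!(power_of_two_le(0x8000), Some(0x8000));
        // Other values are rounded down to the previous power of two.
        assert_eq!(power_of_two_le(5), Some(4));
        assert_eq!(power_of_two_le(1000), Some(512));
        // Above 0x8000, the next power of two overflows `u16`, so `None` is returned and the
        // caller falls back to its default queue size.
        assert_eq!(power_of_two_le(0x8001), None);
    }
}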
709