xref: /aosp_15_r20/external/crosvm/devices/src/virtio/balloon.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2017 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 mod sys;
6 
7 use std::collections::BTreeMap;
8 use std::collections::VecDeque;
9 use std::io::Write;
10 use std::sync::Arc;
11 
12 use anyhow::anyhow;
13 use anyhow::Context;
14 use balloon_control::BalloonStats;
15 use balloon_control::BalloonTubeCommand;
16 use balloon_control::BalloonTubeResult;
17 use balloon_control::BalloonWS;
18 use balloon_control::WSBucket;
19 use balloon_control::VIRTIO_BALLOON_WS_MAX_NUM_BINS;
20 use balloon_control::VIRTIO_BALLOON_WS_MIN_NUM_BINS;
21 use base::debug;
22 use base::error;
23 use base::warn;
24 use base::AsRawDescriptor;
25 use base::Event;
26 use base::RawDescriptor;
27 #[cfg(feature = "registered_events")]
28 use base::SendTube;
29 use base::Tube;
30 use base::WorkerThread;
31 use cros_async::block_on;
32 use cros_async::sync::RwLock as AsyncRwLock;
33 use cros_async::AsyncTube;
34 use cros_async::EventAsync;
35 use cros_async::Executor;
36 #[cfg(feature = "registered_events")]
37 use cros_async::SendTubeAsync;
38 use data_model::Le16;
39 use data_model::Le32;
40 use data_model::Le64;
41 use futures::channel::mpsc;
42 use futures::channel::oneshot;
43 use futures::pin_mut;
44 use futures::select;
45 use futures::select_biased;
46 use futures::FutureExt;
47 use futures::StreamExt;
48 use remain::sorted;
49 use serde::Deserialize;
50 use serde::Serialize;
51 use thiserror::Error as ThisError;
52 #[cfg(windows)]
53 use vm_control::api::VmMemoryClient;
54 #[cfg(feature = "registered_events")]
55 use vm_control::RegisteredEventWithData;
56 use vm_memory::GuestAddress;
57 use vm_memory::GuestMemory;
58 use zerocopy::AsBytes;
59 use zerocopy::FromBytes;
60 use zerocopy::FromZeroes;
61 
62 use super::async_utils;
63 use super::copy_config;
64 use super::create_stop_oneshot;
65 use super::DescriptorChain;
66 use super::DeviceType;
67 use super::Interrupt;
68 use super::Queue;
69 use super::Reader;
70 use super::StoppedWorker;
71 use super::VirtioDevice;
72 use crate::UnpinRequest;
73 use crate::UnpinResponse;
74 
75 #[sorted]
76 #[derive(ThisError, Debug)]
77 pub enum BalloonError {
78     /// Failed an async await
79     #[error("failed async await: {0}")]
80     AsyncAwait(cros_async::AsyncError),
81     /// Failed an async await
82     #[error("failed async await: {0}")]
83     AsyncAwaitAnyhow(anyhow::Error),
84     /// Failed to create event.
85     #[error("failed to create event: {0}")]
86     CreatingEvent(base::Error),
87     /// Failed to create async message receiver.
88     #[error("failed to create async message receiver: {0}")]
89     CreatingMessageReceiver(base::TubeError),
90     /// Failed to receive command message.
91     #[error("failed to receive command message: {0}")]
92     ReceivingCommand(base::TubeError),
93     /// Failed to send command response.
94     #[error("failed to send command response: {0}")]
95     SendResponse(base::TubeError),
96     /// Error while writing to virtqueue
97     #[error("failed to write to virtqueue: {0}")]
98     WriteQueue(std::io::Error),
99     /// Failed to write config event.
100     #[error("failed to write config event: {0}")]
101     WritingConfigEvent(base::Error),
102 }
103 pub type Result<T> = std::result::Result<T, BalloonError>;
104 
// Every balloon virtqueue uses the same maximum size. The device can expose the inflate,
// deflate, stats, page-reporting, WS data and WS command queues depending on the features in use.
// NOTE(review): QUEUE_SIZES has five entries while `BalloonQueues` can hold six queues; confirm
// this matches how the size table is consumed.
const QUEUE_SIZE: u16 = 128;
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; 5];

// Virtqueue indexes, following the virtio balloon numbering.
const INFLATEQ: usize = 0;
const DEFLATEQ: usize = 1;
const STATSQ: usize = 2;
const _FREE_PAGE_VQ: usize = 3;
const REPORTING_VQ: usize = 4;
const WS_DATA_VQ: usize = 5;
const WS_OP_VQ: usize = 6;

// Balloon page frame numbers always describe 4 KiB (1 << 12) pages.
const VIRTIO_BALLOON_PFN_SHIFT: u32 = 12;
const VIRTIO_BALLOON_PF_SIZE: u64 = 1u64 << VIRTIO_BALLOON_PFN_SHIFT;

// Bit numbers in the virtio balloon feature bitmap.
const VIRTIO_BALLOON_F_MUST_TELL_HOST: u32 = 0; // Tell before reclaiming pages
const VIRTIO_BALLOON_F_STATS_VQ: u32 = 1; // Stats reporting enabled
const VIRTIO_BALLOON_F_DEFLATE_ON_OOM: u32 = 2; // Deflate balloon on OOM
const VIRTIO_BALLOON_F_PAGE_REPORTING: u32 = 5; // Page reporting virtqueue
// TODO(b/273973298): this should maybe be bit 6? to be changed later
const VIRTIO_BALLOON_F_WS_REPORTING: u32 = 8; // Working Set Reporting virtqueues
128 
129 #[derive(Copy, Clone)]
130 #[repr(u32)]
131 // Balloon virtqueues
132 pub enum BalloonFeatures {
133     // Page Reporting enabled
134     PageReporting = VIRTIO_BALLOON_F_PAGE_REPORTING,
135     // WS Reporting enabled
136     WSReporting = VIRTIO_BALLOON_F_WS_REPORTING,
137 }
138 
139 // virtio_balloon_config is the balloon device configuration space defined by the virtio spec.
140 #[derive(Copy, Clone, Debug, Default, AsBytes, FromZeroes, FromBytes)]
141 #[repr(C)]
142 struct virtio_balloon_config {
143     num_pages: Le32,
144     actual: Le32,
145     free_page_hint_cmd_id: Le32,
146     poison_val: Le32,
147     // WS field is part of proposed spec extension (b/273973298).
148     ws_num_bins: u8,
149     _reserved: [u8; 3],
150 }
151 
152 // BalloonState is shared by the worker and device thread.
153 #[derive(Clone, Default, Serialize, Deserialize)]
154 struct BalloonState {
155     num_pages: u32,
156     actual_pages: u32,
157     expecting_ws: bool,
158     // Flag indicating that the balloon is in the process of a failable update. This
159     // is set by an Adjust command that has allow_failure set, and is cleared when the
160     // Adjusted success/failure response is sent.
161     failable_update: bool,
162     pending_adjusted_responses: VecDeque<u32>,
163 }
164 
// Tags of the `(tag, value)` records read from the stats queue (see `BalloonStat`).
const VIRTIO_BALLOON_S_SWAP_IN: u16 = 0; // -> BalloonStats::swap_in
const VIRTIO_BALLOON_S_SWAP_OUT: u16 = 1; // -> BalloonStats::swap_out
const VIRTIO_BALLOON_S_MAJFLT: u16 = 2; // -> BalloonStats::major_faults
const VIRTIO_BALLOON_S_MINFLT: u16 = 3; // -> BalloonStats::minor_faults
const VIRTIO_BALLOON_S_MEMFREE: u16 = 4; // -> BalloonStats::free_memory
const VIRTIO_BALLOON_S_MEMTOT: u16 = 5; // -> BalloonStats::total_memory
const VIRTIO_BALLOON_S_AVAIL: u16 = 6; // -> BalloonStats::available_memory
const VIRTIO_BALLOON_S_CACHES: u16 = 7; // -> BalloonStats::disk_caches
const VIRTIO_BALLOON_S_HTLB_PGALLOC: u16 = 8; // -> BalloonStats::hugetlb_allocations
const VIRTIO_BALLOON_S_HTLB_PGFAIL: u16 = 9; // -> BalloonStats::hugetlb_failures
// Non-standard tags occupy the top of the 16-bit tag space.
const VIRTIO_BALLOON_S_NONSTANDARD_SHMEM: u16 = u16::MAX - 1;
const VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE: u16 = u16::MAX;
178 
179 // BalloonStat is used to deserialize stats from the stats_queue.
180 #[derive(Copy, Clone, FromZeroes, FromBytes, AsBytes)]
181 #[repr(C, packed)]
182 struct BalloonStat {
183     tag: Le16,
184     val: Le64,
185 }
186 
187 impl BalloonStat {
update_stats(&self, stats: &mut BalloonStats)188     fn update_stats(&self, stats: &mut BalloonStats) {
189         let val = Some(self.val.to_native());
190         match self.tag.to_native() {
191             VIRTIO_BALLOON_S_SWAP_IN => stats.swap_in = val,
192             VIRTIO_BALLOON_S_SWAP_OUT => stats.swap_out = val,
193             VIRTIO_BALLOON_S_MAJFLT => stats.major_faults = val,
194             VIRTIO_BALLOON_S_MINFLT => stats.minor_faults = val,
195             VIRTIO_BALLOON_S_MEMFREE => stats.free_memory = val,
196             VIRTIO_BALLOON_S_MEMTOT => stats.total_memory = val,
197             VIRTIO_BALLOON_S_AVAIL => stats.available_memory = val,
198             VIRTIO_BALLOON_S_CACHES => stats.disk_caches = val,
199             VIRTIO_BALLOON_S_HTLB_PGALLOC => stats.hugetlb_allocations = val,
200             VIRTIO_BALLOON_S_HTLB_PGFAIL => stats.hugetlb_failures = val,
201             VIRTIO_BALLOON_S_NONSTANDARD_SHMEM => stats.shared_memory = val,
202             VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE => stats.unevictable_memory = val,
203             _ => (),
204         }
205     }
206 }
207 
208 // virtio_balloon_ws is used to deserialize from the ws data vq.
209 #[repr(C)]
210 #[derive(Copy, Clone, Debug, Default, AsBytes, FromZeroes, FromBytes)]
211 struct virtio_balloon_ws {
212     tag: Le16,
213     node_id: Le16,
214     // virtio prefers field members to align on a word boundary so we must pad. see:
215     // https://crsrc.org/o/src/third_party/kernel/v5.15/include/uapi/linux/virtio_balloon.h;l=105
216     _reserved: [u8; 4],
217     idle_age_ms: Le64,
218     // TODO(b/273973298): these should become separate fields - bytes for ANON and FILE
219     memory_size_bytes: [Le64; 2],
220 }
221 
222 impl virtio_balloon_ws {
update_ws(&self, ws: &mut BalloonWS)223     fn update_ws(&self, ws: &mut BalloonWS) {
224         let bucket = WSBucket {
225             age: self.idle_age_ms.to_native(),
226             bytes: [
227                 self.memory_size_bytes[0].to_native(),
228                 self.memory_size_bytes[1].to_native(),
229             ],
230         };
231         ws.ws.push(bucket);
232     }
233 }
234 
// Operation codes stored in `virtio_balloon_op::type_` on the WS command virtqueue.
const _VIRTIO_BALLOON_WS_OP_INVALID: u16 = 0;
const VIRTIO_BALLOON_WS_OP_REQUEST: u16 = 1; // Ask the guest for a working-set report.
const VIRTIO_BALLOON_WS_OP_CONFIG: u16 = 2; // Push bins and thresholds to the guest.
const _VIRTIO_BALLOON_WS_OP_DISCARD: u16 = 3;
239 
240 // virtio_balloon_op is used to serialize to the ws cmd vq.
241 #[repr(C, packed)]
242 #[derive(Copy, Clone, Debug, Default, AsBytes, FromZeroes, FromBytes)]
243 struct virtio_balloon_op {
244     type_: Le16,
245 }
246 
invoke_desc_handler<F>(ranges: Vec<(u64, u64)>, desc_handler: &mut F) where F: FnMut(GuestAddress, u64),247 fn invoke_desc_handler<F>(ranges: Vec<(u64, u64)>, desc_handler: &mut F)
248 where
249     F: FnMut(GuestAddress, u64),
250 {
251     for range in ranges {
252         desc_handler(GuestAddress(range.0), range.1);
253     }
254 }
255 
256 // Release a list of guest memory ranges back to the host system.
257 // Unpin requests for each inflate range will be sent via `release_memory_tube`
258 // if provided, and then `desc_handler` will be called for each inflate range.
release_ranges<F>( release_memory_tube: Option<&Tube>, inflate_ranges: Vec<(u64, u64)>, desc_handler: &mut F, ) -> anyhow::Result<()> where F: FnMut(GuestAddress, u64),259 fn release_ranges<F>(
260     release_memory_tube: Option<&Tube>,
261     inflate_ranges: Vec<(u64, u64)>,
262     desc_handler: &mut F,
263 ) -> anyhow::Result<()>
264 where
265     F: FnMut(GuestAddress, u64),
266 {
267     if let Some(tube) = release_memory_tube {
268         let unpin_ranges = inflate_ranges
269             .iter()
270             .map(|v| {
271                 (
272                     v.0 >> VIRTIO_BALLOON_PFN_SHIFT,
273                     v.1 / VIRTIO_BALLOON_PF_SIZE,
274                 )
275             })
276             .collect();
277         let req = UnpinRequest {
278             ranges: unpin_ranges,
279         };
280         if let Err(e) = tube.send(&req) {
281             error!("failed to send unpin request: {}", e);
282         } else {
283             match tube.recv() {
284                 Ok(resp) => match resp {
285                     UnpinResponse::Success => invoke_desc_handler(inflate_ranges, desc_handler),
286                     UnpinResponse::Failed => error!("failed to handle unpin request"),
287                 },
288                 Err(e) => error!("failed to handle get unpin response: {}", e),
289             }
290         }
291     } else {
292         invoke_desc_handler(inflate_ranges, desc_handler);
293     }
294 
295     Ok(())
296 }
297 
298 // Processes one message's list of addresses.
handle_address_chain<F>( release_memory_tube: Option<&Tube>, avail_desc: &mut DescriptorChain, desc_handler: &mut F, ) -> anyhow::Result<()> where F: FnMut(GuestAddress, u64),299 fn handle_address_chain<F>(
300     release_memory_tube: Option<&Tube>,
301     avail_desc: &mut DescriptorChain,
302     desc_handler: &mut F,
303 ) -> anyhow::Result<()>
304 where
305     F: FnMut(GuestAddress, u64),
306 {
307     // In a long-running system, there is no reason to expect that
308     // a significant number of freed pages are consecutive. However,
309     // batching is relatively simple and can result in significant
310     // gains in a newly booted system, so it's worth attempting.
311     let mut range_start = 0;
312     let mut range_size = 0;
313     let mut inflate_ranges: Vec<(u64, u64)> = Vec::new();
314     for res in avail_desc.reader.iter::<Le32>() {
315         let pfn = match res {
316             Ok(pfn) => pfn,
317             Err(e) => {
318                 error!("error while reading unused pages: {}", e);
319                 break;
320             }
321         };
322         let guest_address = (u64::from(pfn.to_native())) << VIRTIO_BALLOON_PFN_SHIFT;
323         if range_start + range_size == guest_address {
324             range_size += VIRTIO_BALLOON_PF_SIZE;
325         } else if range_start == guest_address + VIRTIO_BALLOON_PF_SIZE {
326             range_start = guest_address;
327             range_size += VIRTIO_BALLOON_PF_SIZE;
328         } else {
329             // Discontinuity, so flush the previous range. Note range_size
330             // will be 0 on the first iteration, so skip that.
331             if range_size != 0 {
332                 inflate_ranges.push((range_start, range_size));
333             }
334             range_start = guest_address;
335             range_size = VIRTIO_BALLOON_PF_SIZE;
336         }
337     }
338     if range_size != 0 {
339         inflate_ranges.push((range_start, range_size));
340     }
341 
342     release_ranges(release_memory_tube, inflate_ranges, desc_handler)
343 }
344 
345 // Async task that handles the main balloon inflate and deflate queues.
handle_queue<F>( mut queue: Queue, mut queue_event: EventAsync, release_memory_tube: Option<&Tube>, mut desc_handler: F, mut stop_rx: oneshot::Receiver<()>, ) -> Queue where F: FnMut(GuestAddress, u64),346 async fn handle_queue<F>(
347     mut queue: Queue,
348     mut queue_event: EventAsync,
349     release_memory_tube: Option<&Tube>,
350     mut desc_handler: F,
351     mut stop_rx: oneshot::Receiver<()>,
352 ) -> Queue
353 where
354     F: FnMut(GuestAddress, u64),
355 {
356     loop {
357         let mut avail_desc = match queue
358             .next_async_interruptable(&mut queue_event, &mut stop_rx)
359             .await
360         {
361             Ok(Some(res)) => res,
362             Ok(None) => return queue,
363             Err(e) => {
364                 error!("Failed to read descriptor {}", e);
365                 return queue;
366             }
367         };
368         if let Err(e) =
369             handle_address_chain(release_memory_tube, &mut avail_desc, &mut desc_handler)
370         {
371             error!("balloon: failed to process inflate addresses: {}", e);
372         }
373         queue.add_used(avail_desc, 0);
374         queue.trigger_interrupt();
375     }
376 }
377 
378 // Processes one page-reporting descriptor.
handle_reported_buffer<F>( release_memory_tube: Option<&Tube>, avail_desc: &DescriptorChain, desc_handler: &mut F, ) -> anyhow::Result<()> where F: FnMut(GuestAddress, u64),379 fn handle_reported_buffer<F>(
380     release_memory_tube: Option<&Tube>,
381     avail_desc: &DescriptorChain,
382     desc_handler: &mut F,
383 ) -> anyhow::Result<()>
384 where
385     F: FnMut(GuestAddress, u64),
386 {
387     let reported_ranges: Vec<(u64, u64)> = avail_desc
388         .reader
389         .get_remaining_regions()
390         .chain(avail_desc.writer.get_remaining_regions())
391         .map(|r| (r.offset, r.len as u64))
392         .collect();
393 
394     release_ranges(release_memory_tube, reported_ranges, desc_handler)
395 }
396 
397 // Async task that handles the page reporting queue.
handle_reporting_queue<F>( mut queue: Queue, mut queue_event: EventAsync, release_memory_tube: Option<&Tube>, mut desc_handler: F, mut stop_rx: oneshot::Receiver<()>, ) -> Queue where F: FnMut(GuestAddress, u64),398 async fn handle_reporting_queue<F>(
399     mut queue: Queue,
400     mut queue_event: EventAsync,
401     release_memory_tube: Option<&Tube>,
402     mut desc_handler: F,
403     mut stop_rx: oneshot::Receiver<()>,
404 ) -> Queue
405 where
406     F: FnMut(GuestAddress, u64),
407 {
408     loop {
409         let avail_desc = match queue
410             .next_async_interruptable(&mut queue_event, &mut stop_rx)
411             .await
412         {
413             Ok(Some(res)) => res,
414             Ok(None) => return queue,
415             Err(e) => {
416                 error!("Failed to read descriptor {}", e);
417                 return queue;
418             }
419         };
420         if let Err(e) = handle_reported_buffer(release_memory_tube, &avail_desc, &mut desc_handler)
421         {
422             error!("balloon: failed to process reported buffer: {}", e);
423         }
424         queue.add_used(avail_desc, 0);
425         queue.trigger_interrupt();
426     }
427 }
428 
parse_balloon_stats(reader: &mut Reader) -> BalloonStats429 fn parse_balloon_stats(reader: &mut Reader) -> BalloonStats {
430     let mut stats: BalloonStats = Default::default();
431     for res in reader.iter::<BalloonStat>() {
432         match res {
433             Ok(stat) => stat.update_stats(&mut stats),
434             Err(e) => {
435                 error!("error while reading stats: {}", e);
436                 break;
437             }
438         };
439     }
440     stats
441 }
442 
443 // Async task that handles the stats queue. Note that the cadence of this is driven by requests for
444 // balloon stats from the control pipe.
445 // The guests queues an initial buffer on boot, which is read and then this future will block until
446 // signaled from the command socket that stats should be collected again.
handle_stats_queue( mut queue: Queue, mut queue_event: EventAsync, mut stats_rx: mpsc::Receiver<()>, command_tube: &AsyncTube, #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>, state: Arc<AsyncRwLock<BalloonState>>, mut stop_rx: oneshot::Receiver<()>, ) -> Queue447 async fn handle_stats_queue(
448     mut queue: Queue,
449     mut queue_event: EventAsync,
450     mut stats_rx: mpsc::Receiver<()>,
451     command_tube: &AsyncTube,
452     #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>,
453     state: Arc<AsyncRwLock<BalloonState>>,
454     mut stop_rx: oneshot::Receiver<()>,
455 ) -> Queue {
456     let mut avail_desc = match queue
457         .next_async_interruptable(&mut queue_event, &mut stop_rx)
458         .await
459     {
460         // Consume the first stats buffer sent from the guest at startup. It was not
461         // requested by anyone, and the stats are stale.
462         Ok(Some(res)) => res,
463         Ok(None) => return queue,
464         Err(e) => {
465             error!("Failed to read descriptor {}", e);
466             return queue;
467         }
468     };
469 
470     loop {
471         select_biased! {
472             msg = stats_rx.next() => {
473                 // Wait for a request to read the stats.
474                 match msg {
475                     Some(()) => (),
476                     None => {
477                         error!("stats signal channel was closed");
478                         return queue;
479                     }
480                 }
481             }
482             _ = stop_rx => return queue,
483         };
484 
485         // Request a new stats_desc to the guest.
486         queue.add_used(avail_desc, 0);
487         queue.trigger_interrupt();
488 
489         avail_desc = match queue.next_async(&mut queue_event).await {
490             Err(e) => {
491                 error!("Failed to read descriptor {}", e);
492                 return queue;
493             }
494             Ok(d) => d,
495         };
496         let stats = parse_balloon_stats(&mut avail_desc.reader);
497 
498         let actual_pages = state.lock().await.actual_pages as u64;
499         let result = BalloonTubeResult::Stats {
500             balloon_actual: actual_pages << VIRTIO_BALLOON_PFN_SHIFT,
501             stats,
502         };
503         let send_result = command_tube.send(result).await;
504         if let Err(e) = send_result {
505             error!("failed to send stats result: {}", e);
506         }
507 
508         #[cfg(feature = "registered_events")]
509         if let Some(registered_evt_q) = registered_evt_q {
510             if let Err(e) = registered_evt_q
511                 .send(&RegisteredEventWithData::VirtioBalloonResize)
512                 .await
513             {
514                 error!("failed to send VirtioBalloonResize event: {}", e);
515             }
516         }
517     }
518 }
519 
send_adjusted_response( tube: &AsyncTube, num_pages: u32, ) -> std::result::Result<(), base::TubeError>520 async fn send_adjusted_response(
521     tube: &AsyncTube,
522     num_pages: u32,
523 ) -> std::result::Result<(), base::TubeError> {
524     let num_bytes = (num_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
525     let result = BalloonTubeResult::Adjusted { num_bytes };
526     tube.send(result).await
527 }
528 
/// Working-set operations forwarded from the command tube to the WS command queue task.
enum WSOp {
    /// Ask the guest for a working-set report.
    WSReport,
    /// Send a new working-set configuration to the guest.
    WSConfig {
        /// Bin boundaries, written to the guest as consecutive 32-bit values.
        bins: Vec<u32>,
        refresh_threshold: u32,
        report_threshold: u32,
    },
}
537 
handle_ws_op_queue( mut queue: Queue, mut queue_event: EventAsync, mut ws_op_rx: mpsc::Receiver<WSOp>, state: Arc<AsyncRwLock<BalloonState>>, mut stop_rx: oneshot::Receiver<()>, ) -> Result<Queue>538 async fn handle_ws_op_queue(
539     mut queue: Queue,
540     mut queue_event: EventAsync,
541     mut ws_op_rx: mpsc::Receiver<WSOp>,
542     state: Arc<AsyncRwLock<BalloonState>>,
543     mut stop_rx: oneshot::Receiver<()>,
544 ) -> Result<Queue> {
545     loop {
546         let op = select_biased! {
547             next_op = ws_op_rx.next().fuse() => {
548                 match next_op {
549                     Some(op) => op,
550                     None => {
551                         error!("ws op tube was closed");
552                         break;
553                     }
554                 }
555             }
556             _ = stop_rx => {
557                 break;
558             }
559         };
560         let mut avail_desc = queue
561             .next_async(&mut queue_event)
562             .await
563             .map_err(BalloonError::AsyncAwait)?;
564         let writer = &mut avail_desc.writer;
565 
566         match op {
567             WSOp::WSReport => {
568                 {
569                     let mut state = state.lock().await;
570                     state.expecting_ws = true;
571                 }
572 
573                 let ws_r = virtio_balloon_op {
574                     type_: VIRTIO_BALLOON_WS_OP_REQUEST.into(),
575                 };
576 
577                 writer.write_obj(ws_r).map_err(BalloonError::WriteQueue)?;
578             }
579             WSOp::WSConfig {
580                 bins,
581                 refresh_threshold,
582                 report_threshold,
583             } => {
584                 let cmd = virtio_balloon_op {
585                     type_: VIRTIO_BALLOON_WS_OP_CONFIG.into(),
586                 };
587 
588                 writer.write_obj(cmd).map_err(BalloonError::WriteQueue)?;
589                 writer
590                     .write_all(bins.as_bytes())
591                     .map_err(BalloonError::WriteQueue)?;
592                 writer
593                     .write_obj(refresh_threshold)
594                     .map_err(BalloonError::WriteQueue)?;
595                 writer
596                     .write_obj(report_threshold)
597                     .map_err(BalloonError::WriteQueue)?;
598             }
599         }
600 
601         let len = writer.bytes_written() as u32;
602         queue.add_used(avail_desc, len);
603         queue.trigger_interrupt();
604     }
605 
606     Ok(queue)
607 }
608 
parse_balloon_ws(reader: &mut Reader) -> BalloonWS609 fn parse_balloon_ws(reader: &mut Reader) -> BalloonWS {
610     let mut ws = BalloonWS::new();
611     for res in reader.iter::<virtio_balloon_ws>() {
612         match res {
613             Ok(ws_msg) => {
614                 ws_msg.update_ws(&mut ws);
615             }
616             Err(e) => {
617                 error!("error while reading ws: {}", e);
618                 break;
619             }
620         }
621     }
622     if ws.ws.len() < VIRTIO_BALLOON_WS_MIN_NUM_BINS || ws.ws.len() > VIRTIO_BALLOON_WS_MAX_NUM_BINS
623     {
624         error!("unexpected number of WS buckets: {}", ws.ws.len());
625     }
626     ws
627 }
628 
629 // Async task that handles the stats queue. Note that the arrival of events on
630 // the WS vq may be the result of either a WS request (WS-R) command having
631 // been sent to the guest, or an unprompted send due to memory pressue in the
632 // guest. If the data was requested, we should also send that back on the
633 // command tube.
handle_ws_data_queue( mut queue: Queue, mut queue_event: EventAsync, command_tube: &AsyncTube, #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>, state: Arc<AsyncRwLock<BalloonState>>, mut stop_rx: oneshot::Receiver<()>, ) -> Result<Queue>634 async fn handle_ws_data_queue(
635     mut queue: Queue,
636     mut queue_event: EventAsync,
637     command_tube: &AsyncTube,
638     #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>,
639     state: Arc<AsyncRwLock<BalloonState>>,
640     mut stop_rx: oneshot::Receiver<()>,
641 ) -> Result<Queue> {
642     loop {
643         let mut avail_desc = match queue
644             .next_async_interruptable(&mut queue_event, &mut stop_rx)
645             .await
646             .map_err(BalloonError::AsyncAwait)?
647         {
648             Some(res) => res,
649             None => return Ok(queue),
650         };
651 
652         let ws = parse_balloon_ws(&mut avail_desc.reader);
653 
654         let mut state = state.lock().await;
655 
656         // update ws report with balloon pages now that we have a lock on state
657         let balloon_actual = (state.actual_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
658 
659         if state.expecting_ws {
660             let result = BalloonTubeResult::WorkingSet { ws, balloon_actual };
661             let send_result = command_tube.send(result).await;
662             if let Err(e) = send_result {
663                 error!("failed to send ws result: {}", e);
664             }
665 
666             state.expecting_ws = false;
667         } else {
668             #[cfg(feature = "registered_events")]
669             if let Some(registered_evt_q) = registered_evt_q {
670                 if let Err(e) = registered_evt_q
671                     .send(RegisteredEventWithData::from_ws(&ws, balloon_actual))
672                     .await
673                 {
674                     error!("failed to send VirtioBalloonWSReport event: {}", e);
675                 }
676             }
677         }
678 
679         queue.add_used(avail_desc, 0);
680         queue.trigger_interrupt();
681     }
682 }
683 
684 // Async task that handles the command socket. The command socket handles messages from the host
685 // requesting that the guest balloon be adjusted or to report guest memory statistics.
handle_command_tube( command_tube: &AsyncTube, interrupt: Interrupt, state: Arc<AsyncRwLock<BalloonState>>, mut stats_tx: mpsc::Sender<()>, mut ws_op_tx: mpsc::Sender<WSOp>, mut stop_rx: oneshot::Receiver<()>, ) -> Result<()>686 async fn handle_command_tube(
687     command_tube: &AsyncTube,
688     interrupt: Interrupt,
689     state: Arc<AsyncRwLock<BalloonState>>,
690     mut stats_tx: mpsc::Sender<()>,
691     mut ws_op_tx: mpsc::Sender<WSOp>,
692     mut stop_rx: oneshot::Receiver<()>,
693 ) -> Result<()> {
694     loop {
695         let cmd_res = select_biased! {
696             res = command_tube.next().fuse() => res,
697             _ = stop_rx => return Ok(())
698         };
699         match cmd_res {
700             Ok(command) => match command {
701                 BalloonTubeCommand::Adjust {
702                     num_bytes,
703                     allow_failure,
704                 } => {
705                     let num_pages = (num_bytes >> VIRTIO_BALLOON_PFN_SHIFT) as u32;
706                     let mut state = state.lock().await;
707 
708                     state.num_pages = num_pages;
709                     interrupt.signal_config_changed();
710 
711                     if allow_failure {
712                         if num_pages == state.actual_pages {
713                             send_adjusted_response(command_tube, num_pages)
714                                 .await
715                                 .map_err(BalloonError::SendResponse)?;
716                         } else {
717                             state.failable_update = true;
718                         }
719                     }
720                 }
721                 BalloonTubeCommand::WorkingSetConfig {
722                     bins,
723                     refresh_threshold,
724                     report_threshold,
725                 } => {
726                     if let Err(e) = ws_op_tx.try_send(WSOp::WSConfig {
727                         bins,
728                         refresh_threshold,
729                         report_threshold,
730                     }) {
731                         error!("failed to send config to ws handler: {}", e);
732                     }
733                 }
734                 BalloonTubeCommand::Stats => {
735                     if let Err(e) = stats_tx.try_send(()) {
736                         error!("failed to signal the stat handler: {}", e);
737                     }
738                 }
739                 BalloonTubeCommand::WorkingSet => {
740                     if let Err(e) = ws_op_tx.try_send(WSOp::WSReport) {
741                         error!("failed to send report request to ws handler: {}", e);
742                     }
743                 }
744             },
745             #[cfg(windows)]
746             Err(base::TubeError::Recv(e)) if e.kind() == std::io::ErrorKind::TimedOut => {
747                 // On Windows, async IO tasks like the next/recv above are cancelled as the VM is
748                 // shutting down. For the sake of consistency with unix, we can't *just* return
749                 // here; instead, we wait for the stop request to arrive, *and then* return.
750                 //
751                 // The real fix is to get rid of the global unblock pool, since then we won't
752                 // cancel the tasks early (b/196911556).
753                 let _ = stop_rx.await;
754                 return Ok(());
755             }
756             Err(e) => {
757                 return Err(BalloonError::ReceivingCommand(e));
758             }
759         }
760     }
761 }
762 
handle_pending_adjusted_responses( pending_adjusted_response_event: EventAsync, command_tube: &AsyncTube, state: Arc<AsyncRwLock<BalloonState>>, ) -> Result<()>763 async fn handle_pending_adjusted_responses(
764     pending_adjusted_response_event: EventAsync,
765     command_tube: &AsyncTube,
766     state: Arc<AsyncRwLock<BalloonState>>,
767 ) -> Result<()> {
768     loop {
769         pending_adjusted_response_event
770             .next_val()
771             .await
772             .map_err(BalloonError::AsyncAwait)?;
773         while let Some(num_pages) = state.lock().await.pending_adjusted_responses.pop_front() {
774             send_adjusted_response(command_tube, num_pages)
775                 .await
776                 .map_err(BalloonError::SendResponse)?;
777         }
778     }
779 }
780 
781 /// Represents queues & events for the balloon device.
782 struct BalloonQueues {
783     inflate: Queue,
784     deflate: Queue,
785     stats: Option<Queue>,
786     reporting: Option<Queue>,
787     ws_data: Option<Queue>,
788     ws_op: Option<Queue>,
789 }
790 
791 impl BalloonQueues {
new(inflate: Queue, deflate: Queue) -> Self792     fn new(inflate: Queue, deflate: Queue) -> Self {
793         BalloonQueues {
794             inflate,
795             deflate,
796             stats: None,
797             reporting: None,
798             ws_data: None,
799             ws_op: None,
800         }
801     }
802 }
803 
804 /// When the worker is stopped, the queues are preserved here.
805 struct PausedQueues {
806     inflate: Queue,
807     deflate: Queue,
808     stats: Option<Queue>,
809     reporting: Option<Queue>,
810     ws_data: Option<Queue>,
811     ws_op: Option<Queue>,
812 }
813 
814 impl PausedQueues {
new(inflate: Queue, deflate: Queue) -> Self815     fn new(inflate: Queue, deflate: Queue) -> Self {
816         PausedQueues {
817             inflate,
818             deflate,
819             stats: None,
820             reporting: None,
821             ws_data: None,
822             ws_op: None,
823         }
824     }
825 }
826 
apply_if_some<F, R>(queue_opt: Option<Queue>, mut func: F) where F: FnMut(Queue) -> R,827 fn apply_if_some<F, R>(queue_opt: Option<Queue>, mut func: F)
828 where
829     F: FnMut(Queue) -> R,
830 {
831     if let Some(queue) = queue_opt {
832         func(queue);
833     }
834 }
835 
836 impl From<Box<PausedQueues>> for BTreeMap<usize, Queue> {
from(queues: Box<PausedQueues>) -> BTreeMap<usize, Queue>837     fn from(queues: Box<PausedQueues>) -> BTreeMap<usize, Queue> {
838         let mut ret = Vec::new();
839         ret.push(queues.inflate);
840         ret.push(queues.deflate);
841         apply_if_some(queues.stats, |stats| ret.push(stats));
842         apply_if_some(queues.reporting, |reporting| ret.push(reporting));
843         apply_if_some(queues.ws_data, |ws_data| ret.push(ws_data));
844         apply_if_some(queues.ws_op, |ws_op| ret.push(ws_op));
845         // WARNING: We don't use the indices from the virito spec on purpose, see comment in
846         // get_queues_from_map for the rationale.
847         ret.into_iter().enumerate().collect()
848     }
849 }
850 
/// Stores data from the worker when it stops so that data can be re-used when
/// the worker is restarted.
///
/// Built unconditionally at the end of `run_worker` and unpacked by `Balloon::stop_worker`.
/// Field declaration order is also the drop order.
struct WorkerReturn {
    /// Tube used to send inflate ranges for unpinning, if one was configured.
    release_memory_tube: Option<Tube>,
    /// The balloon command tube, converted back from the worker's `AsyncTube`.
    command_tube: Tube,
    /// Queue for registered events, handed back unchanged.
    #[cfg(feature = "registered_events")]
    registered_evt_q: Option<SendTube>,
    /// The recovered queues, or `None` if the worker did not stop cleanly.
    paused_queues: Option<PausedQueues>,
    /// Client used by the Windows memory free/reclaim and target-reached paths.
    #[cfg(windows)]
    vm_memory_client: VmMemoryClient,
}
862 
/// Body of the balloon worker thread.
///
/// Initializes the asynchronous worker tasks and passes them to a fresh [`Executor`] to be
/// processed until `kill_evt` fires. There is one task per active queue, plus the command
/// tube, IRQ resample, target-reached, kill and pending-adjusted-response handlers.
///
/// On a clean kill, every queue future is asked to stop through its oneshot, and the queues
/// they hand back are collected into [`PausedQueues`]. If any other handler finishes first,
/// or the queues cannot be recovered, the error is logged and `paused_queues` is `None`.
///
/// The tubes (and, on Windows, the memory client) are always returned in the
/// [`WorkerReturn`] so the device can start a new worker later.
///
/// # Panics
/// Panics if the executor, the async tube/event wrappers, or the queue event clones cannot
/// be created.
fn run_worker(
    inflate_queue: Queue,
    deflate_queue: Queue,
    stats_queue: Option<Queue>,
    reporting_queue: Option<Queue>,
    ws_data_queue: Option<Queue>,
    ws_op_queue: Option<Queue>,
    command_tube: Tube,
    #[cfg(windows)] vm_memory_client: VmMemoryClient,
    release_memory_tube: Option<Tube>,
    interrupt: Interrupt,
    kill_evt: Event,
    target_reached_evt: Event,
    pending_adjusted_response_event: Event,
    mem: GuestMemory,
    state: Arc<AsyncRwLock<BalloonState>>,
    #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>,
) -> WorkerReturn {
    let ex = Executor::new().unwrap();
    let command_tube = AsyncTube::new(&ex, command_tube).unwrap();
    #[cfg(feature = "registered_events")]
    let registered_evt_q_async = registered_evt_q
        .as_ref()
        .map(|q| SendTubeAsync::new(q.try_clone().unwrap(), &ex).unwrap());

    // Each `create_stop_oneshot` call below returns a stop receiver for one future; the
    // matching senders are collected here and all signalled once the kill event fires.
    let mut stop_queue_oneshots = Vec::new();

    // We need a block to release all references to command_tube at the end before returning it.
    let paused_queues = {
        // The first queue is used for inflate messages
        let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
        let inflate_queue_evt = inflate_queue
            .event()
            .try_clone()
            .expect("failed to clone queue event");
        let inflate = handle_queue(
            inflate_queue,
            EventAsync::new(inflate_queue_evt, &ex).expect("failed to create async event"),
            release_memory_tube.as_ref(),
            |guest_address, len| {
                sys::free_memory(
                    &guest_address,
                    len,
                    #[cfg(windows)]
                    &vm_memory_client,
                    #[cfg(any(target_os = "android", target_os = "linux"))]
                    &mem,
                )
            },
            stop_rx,
        );
        let inflate = inflate.fuse();
        pin_mut!(inflate);

        // The second queue is used for deflate messages
        let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
        let deflate_queue_evt = deflate_queue
            .event()
            .try_clone()
            .expect("failed to clone queue event");
        let deflate = handle_queue(
            deflate_queue,
            EventAsync::new(deflate_queue_evt, &ex).expect("failed to create async event"),
            None,
            |guest_address, len| {
                sys::reclaim_memory(
                    &guest_address,
                    len,
                    #[cfg(windows)]
                    &vm_memory_client,
                )
            },
            stop_rx,
        );
        let deflate = deflate.fuse();
        pin_mut!(deflate);

        // The next queue is used for stats messages if VIRTIO_BALLOON_F_STATS_VQ is negotiated.
        // `stats_tx` is handed to the command handler so it can request a stats report.
        let (stats_tx, stats_rx) = mpsc::channel::<()>(1);
        let has_stats_queue = stats_queue.is_some();
        let stats = if let Some(stats_queue) = stats_queue {
            let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
            let stats_queue_evt = stats_queue
                .event()
                .try_clone()
                .expect("failed to clone queue event");
            handle_stats_queue(
                stats_queue,
                EventAsync::new(stats_queue_evt, &ex).expect("failed to create async event"),
                stats_rx,
                &command_tube,
                #[cfg(feature = "registered_events")]
                registered_evt_q_async.as_ref(),
                state.clone(),
                stop_rx,
            )
            .left_future()
        } else {
            // No queue: a never-completing placeholder keeps the `select!` below uniform.
            std::future::pending().right_future()
        };
        let stats = stats.fuse();
        pin_mut!(stats);

        // The next queue is used for reporting messages
        let has_reporting_queue = reporting_queue.is_some();
        let reporting = if let Some(reporting_queue) = reporting_queue {
            let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
            let reporting_queue_evt = reporting_queue
                .event()
                .try_clone()
                .expect("failed to clone queue event");
            handle_reporting_queue(
                reporting_queue,
                EventAsync::new(reporting_queue_evt, &ex).expect("failed to create async event"),
                release_memory_tube.as_ref(),
                |guest_address, len| {
                    sys::free_memory(
                        &guest_address,
                        len,
                        #[cfg(windows)]
                        &vm_memory_client,
                        #[cfg(any(target_os = "android", target_os = "linux"))]
                        &mem,
                    )
                },
                stop_rx,
            )
            .left_future()
        } else {
            std::future::pending().right_future()
        };
        let reporting = reporting.fuse();
        pin_mut!(reporting);

        // If VIRTIO_BALLOON_F_WS_REPORTING is set 2 queues must handled - one for WS data and one
        // for WS notifications.
        let has_ws_data_queue = ws_data_queue.is_some();
        let ws_data = if let Some(ws_data_queue) = ws_data_queue {
            let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
            let ws_data_queue_evt = ws_data_queue
                .event()
                .try_clone()
                .expect("failed to clone queue event");
            handle_ws_data_queue(
                ws_data_queue,
                EventAsync::new(ws_data_queue_evt, &ex).expect("failed to create async event"),
                &command_tube,
                #[cfg(feature = "registered_events")]
                registered_evt_q_async.as_ref(),
                state.clone(),
                stop_rx,
            )
            .left_future()
        } else {
            std::future::pending().right_future()
        };
        let ws_data = ws_data.fuse();
        pin_mut!(ws_data);

        // `ws_op_tx` is handed to the command handler so it can forward WS operations.
        let (ws_op_tx, ws_op_rx) = mpsc::channel::<WSOp>(1);
        let has_ws_op_queue = ws_op_queue.is_some();
        let ws_op = if let Some(ws_op_queue) = ws_op_queue {
            let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
            let ws_op_queue_evt = ws_op_queue
                .event()
                .try_clone()
                .expect("failed to clone queue event");
            handle_ws_op_queue(
                ws_op_queue,
                EventAsync::new(ws_op_queue_evt, &ex).expect("failed to create async event"),
                ws_op_rx,
                state.clone(),
                stop_rx,
            )
            .left_future()
        } else {
            std::future::pending().right_future()
        };
        let ws_op = ws_op.fuse();
        pin_mut!(ws_op);

        // Future to handle command messages that resize the balloon.
        let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
        let command = handle_command_tube(
            &command_tube,
            interrupt.clone(),
            state.clone(),
            stats_tx,
            ws_op_tx,
            stop_rx,
        );
        pin_mut!(command);

        // Process any requests to resample the irq value.
        let resample = async_utils::handle_irq_resample(&ex, interrupt.clone());
        pin_mut!(resample);

        // Send a message if balloon target reached event is triggered.
        let target_reached = handle_target_reached(
            &ex,
            target_reached_evt,
            #[cfg(windows)]
            &vm_memory_client,
        );
        pin_mut!(target_reached);

        // Exit if the kill event is triggered.
        let kill = async_utils::await_and_exit(&ex, kill_evt);
        pin_mut!(kill);

        // Drain responses queued by `write_config` whenever it signals the event.
        let pending_adjusted = handle_pending_adjusted_responses(
            EventAsync::new(pending_adjusted_response_event, &ex)
                .expect("failed to create async event"),
            &command_tube,
            state,
        );
        pin_mut!(pending_adjusted);

        let res = ex.run_until(async {
            // `kill` is the only handler expected to finish; any other one completing means
            // the worker can no longer make progress.
            select! {
                _ = kill.fuse() => (),
                _ = inflate => return Err(anyhow!("inflate stopped unexpectedly")),
                _ = deflate => return Err(anyhow!("deflate stopped unexpectedly")),
                _ = stats => return Err(anyhow!("stats stopped unexpectedly")),
                _ = reporting => return Err(anyhow!("reporting stopped unexpectedly")),
                _ = command.fuse() => return Err(anyhow!("command stopped unexpectedly")),
                _ = ws_op => return Err(anyhow!("ws_op stopped unexpectedly")),
                _ = resample.fuse() => return Err(anyhow!("resample stopped unexpectedly")),
                _ = pending_adjusted.fuse() => return Err(anyhow!("pending_adjusted stopped unexpectedly")),
                _ = ws_data => return Err(anyhow!("ws_data stopped unexpectedly")),
                _ = target_reached.fuse() => return Err(anyhow!("target_reached stopped unexpectedly")),
            }

            // Worker is shutting down. To recover the queues, we have to signal
            // all the queue futures to exit.
            // (The command future is signalled too, but it is not awaited below.)
            for stop_tx in stop_queue_oneshots {
                if stop_tx.send(()).is_err() {
                    return Err(anyhow!("failed to request stop for queue future"));
                }
            }

            // Collect all the queues (awaiting any queue future should now
            // return its Queue immediately).
            let mut paused_queues = PausedQueues::new(
                inflate.await,
                deflate.await,
            );
            if has_reporting_queue {
                paused_queues.reporting = Some(reporting.await);
            }
            if has_stats_queue {
                paused_queues.stats = Some(stats.await);
            }
            if has_ws_data_queue {
                paused_queues.ws_data = Some(ws_data.await.context("failed to stop ws_data queue")?);
            }
            if has_ws_op_queue {
                paused_queues.ws_op = Some(ws_op.await.context("failed to stop ws_op queue")?);
            }
            Ok(paused_queues)
        });

        // Any failure (executor-level or from the main future) is logged and yields
        // `None`, which `stop_worker` reports as `StoppedWorker::MissingQueues`.
        match res {
            Err(e) => {
                error!("error happened in executor: {}", e);
                None
            }
            Ok(main_future_res) => match main_future_res {
                Ok(paused_queues) => Some(paused_queues),
                Err(e) => {
                    error!("error happened in main balloon future: {}", e);
                    None
                }
            },
        }
    };

    WorkerReturn {
        command_tube: command_tube.into(),
        paused_queues,
        release_memory_tube,
        #[cfg(feature = "registered_events")]
        registered_evt_q,
        #[cfg(windows)]
        vm_memory_client,
    }
}
1152 
handle_target_reached( ex: &Executor, target_reached_evt: Event, #[cfg(windows)] vm_memory_client: &VmMemoryClient, ) -> anyhow::Result<()>1153 async fn handle_target_reached(
1154     ex: &Executor,
1155     target_reached_evt: Event,
1156     #[cfg(windows)] vm_memory_client: &VmMemoryClient,
1157 ) -> anyhow::Result<()> {
1158     let event_async =
1159         EventAsync::new(target_reached_evt, ex).context("failed to create EventAsync")?;
1160     loop {
1161         // Wait for target reached trigger.
1162         let _ = event_async.next_val().await;
1163         // Send the message to vm_control on the event. We don't have to read the current
1164         // size yet.
1165         sys::balloon_target_reached(
1166             0,
1167             #[cfg(windows)]
1168             vm_memory_client,
1169         );
1170     }
1171     // The above loop will never terminate and there is no reason to terminate it either. However,
1172     // the function is used in an executor that expects a Result<> return. Make sure that clippy
1173     // doesn't enforce the unreachable_code condition.
1174     #[allow(unreachable_code)]
1175     Ok(())
1176 }
1177 
/// Virtio device for memory balloon inflation/deflation.
///
/// Several resources are `Option`s because they are moved into the worker thread while it
/// runs (`start_worker`) and moved back when it stops (`stop_worker`). Field declaration
/// order is also the drop order.
pub struct Balloon {
    /// Balloon control tube; `None` while the worker thread owns it.
    command_tube: Option<Tube>,
    /// Memory client used on Windows; `None` while the worker thread owns it.
    #[cfg(windows)]
    vm_memory_client: Option<VmMemoryClient>,
    /// Optional tube used to unpin inflated ranges; moved into the worker while it runs.
    release_memory_tube: Option<Tube>,
    /// Signalled by `write_config` after queuing an adjusted response; the worker waits on
    /// a clone of it.
    pending_adjusted_response_event: Event,
    /// Balloon state shared between the device and its worker tasks.
    state: Arc<AsyncRwLock<BalloonState>>,
    /// Feature bits offered to the driver.
    features: u64,
    /// Feature bits acked by the driver.
    acked_features: u64,
    /// The running worker thread, if any.
    worker_thread: Option<WorkerThread<WorkerReturn>>,
    /// Registered-event queue; moved into the worker while it runs.
    #[cfg(feature = "registered_events")]
    registered_evt_q: Option<SendTube>,
    /// Working-set bin count exposed in the config space.
    ws_num_bins: u8,
    /// Signalled by `write_config` when `actual` equals `num_pages`; recreated by each
    /// `start_worker` call.
    target_reached_evt: Option<Event>,
}
1194 
/// Snapshot of the [Balloon] state.
///
/// Produced by `virtio_snapshot` and consumed by `virtio_restore`. Field order is kept
/// stable because this struct is serialized.
#[derive(Serialize, Deserialize)]
struct BalloonSnapshot {
    /// Copy of the shared balloon state.
    state: BalloonState,
    /// Offered feature bits; restore fails unless they match the live device.
    features: u64,
    /// Feature bits the driver had acked.
    acked_features: u64,
    /// Working-set bin count exposed in the config space.
    ws_num_bins: u8,
}
1203 
1204 impl Balloon {
1205     /// Creates a new virtio balloon device.
1206     /// To let Balloon able to successfully release the memory which are pinned
1207     /// by CoIOMMU to host, the release_memory_tube will be used to send the inflate
1208     /// ranges to CoIOMMU with UnpinRequest/UnpinResponse messages, so that The
1209     /// memory in the inflate range can be unpinned first.
new( base_features: u64, command_tube: Tube, #[cfg(windows)] vm_memory_client: VmMemoryClient, release_memory_tube: Option<Tube>, init_balloon_size: u64, enabled_features: u64, #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>, ws_num_bins: u8, ) -> Result<Balloon>1210     pub fn new(
1211         base_features: u64,
1212         command_tube: Tube,
1213         #[cfg(windows)] vm_memory_client: VmMemoryClient,
1214         release_memory_tube: Option<Tube>,
1215         init_balloon_size: u64,
1216         enabled_features: u64,
1217         #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>,
1218         ws_num_bins: u8,
1219     ) -> Result<Balloon> {
1220         let features = base_features
1221             | 1 << VIRTIO_BALLOON_F_MUST_TELL_HOST
1222             | 1 << VIRTIO_BALLOON_F_STATS_VQ
1223             | 1 << VIRTIO_BALLOON_F_DEFLATE_ON_OOM
1224             | enabled_features;
1225 
1226         Ok(Balloon {
1227             command_tube: Some(command_tube),
1228             #[cfg(windows)]
1229             vm_memory_client: Some(vm_memory_client),
1230             release_memory_tube,
1231             pending_adjusted_response_event: Event::new().map_err(BalloonError::CreatingEvent)?,
1232             state: Arc::new(AsyncRwLock::new(BalloonState {
1233                 num_pages: (init_balloon_size >> VIRTIO_BALLOON_PFN_SHIFT) as u32,
1234                 actual_pages: 0,
1235                 failable_update: false,
1236                 pending_adjusted_responses: VecDeque::new(),
1237                 expecting_ws: false,
1238             })),
1239             worker_thread: None,
1240             features,
1241             acked_features: 0,
1242             #[cfg(feature = "registered_events")]
1243             registered_evt_q,
1244             ws_num_bins,
1245             target_reached_evt: None,
1246         })
1247     }
1248 
get_config(&self) -> virtio_balloon_config1249     fn get_config(&self) -> virtio_balloon_config {
1250         let state = block_on(self.state.lock());
1251         virtio_balloon_config {
1252             num_pages: state.num_pages.into(),
1253             actual: state.actual_pages.into(),
1254             // crosvm does not (currently) use free_page_hint_cmd_id or
1255             // poison_val, but they must be present in the right order and size
1256             // for the virtio-balloon driver in the guest to deserialize the
1257             // config correctly.
1258             free_page_hint_cmd_id: 0.into(),
1259             poison_val: 0.into(),
1260             ws_num_bins: self.ws_num_bins,
1261             _reserved: [0, 0, 0],
1262         }
1263     }
1264 
stop_worker(&mut self) -> StoppedWorker<PausedQueues>1265     fn stop_worker(&mut self) -> StoppedWorker<PausedQueues> {
1266         if let Some(worker_thread) = self.worker_thread.take() {
1267             let worker_ret = worker_thread.stop();
1268             self.release_memory_tube = worker_ret.release_memory_tube;
1269             self.command_tube = Some(worker_ret.command_tube);
1270             #[cfg(feature = "registered_events")]
1271             {
1272                 self.registered_evt_q = worker_ret.registered_evt_q;
1273             }
1274             #[cfg(windows)]
1275             {
1276                 self.vm_memory_client = Some(worker_ret.vm_memory_client);
1277             }
1278 
1279             if let Some(queues) = worker_ret.paused_queues {
1280                 StoppedWorker::WithQueues(Box::new(queues))
1281             } else {
1282                 StoppedWorker::MissingQueues
1283             }
1284         } else {
1285             StoppedWorker::AlreadyStopped
1286         }
1287     }
1288 
1289     /// Given a filtered queue vector from [VirtioDevice::activate], extract
1290     /// the queues (accounting for queues that are missing because the features
1291     /// are not negotiated) into a structure that is easier to work with.
get_queues_from_map( &self, mut queues: BTreeMap<usize, Queue>, ) -> anyhow::Result<BalloonQueues>1292     fn get_queues_from_map(
1293         &self,
1294         mut queues: BTreeMap<usize, Queue>,
1295     ) -> anyhow::Result<BalloonQueues> {
1296         fn pop_queue(
1297             queues: &mut BTreeMap<usize, Queue>,
1298             expected_index: usize,
1299             name: &str,
1300         ) -> anyhow::Result<Queue> {
1301             let (queue_index, queue) = queues
1302                 .pop_first()
1303                 .with_context(|| format!("missing {}", name))?;
1304 
1305             if queue_index == expected_index {
1306                 debug!("{name} index {queue_index}");
1307             } else {
1308                 warn!("expected {name} index {expected_index}, got {queue_index}");
1309             }
1310 
1311             Ok(queue)
1312         }
1313 
1314         // WARNING: We use `pop_first` instead of explicitly using the indices from the virtio spec
1315         // because the Linux virtio drivers only "allocates" queue indices that are used, so queues
1316         // need to be removed in order of ascending virtqueue index.
1317         let inflate_queue = pop_queue(&mut queues, INFLATEQ, "inflateq")?;
1318         let deflate_queue = pop_queue(&mut queues, DEFLATEQ, "deflateq")?;
1319         let mut queue_struct = BalloonQueues::new(inflate_queue, deflate_queue);
1320 
1321         if self.acked_features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1322             queue_struct.stats = Some(pop_queue(&mut queues, STATSQ, "statsq")?);
1323         }
1324         if self.acked_features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1325             queue_struct.reporting = Some(pop_queue(&mut queues, REPORTING_VQ, "reporting_vq")?);
1326         }
1327         if self.acked_features & (1 << VIRTIO_BALLOON_F_WS_REPORTING) != 0 {
1328             queue_struct.ws_data = Some(pop_queue(&mut queues, WS_DATA_VQ, "ws_data_vq")?);
1329             queue_struct.ws_op = Some(pop_queue(&mut queues, WS_OP_VQ, "ws_op_vq")?);
1330         }
1331 
1332         if !queues.is_empty() {
1333             return Err(anyhow!("unexpected queues {:?}", queues.into_keys()));
1334         }
1335 
1336         Ok(queue_struct)
1337     }
1338 
start_worker( &mut self, mem: GuestMemory, interrupt: Interrupt, queues: BalloonQueues, ) -> anyhow::Result<()>1339     fn start_worker(
1340         &mut self,
1341         mem: GuestMemory,
1342         interrupt: Interrupt,
1343         queues: BalloonQueues,
1344     ) -> anyhow::Result<()> {
1345         let (self_target_reached_evt, target_reached_evt) = Event::new()
1346             .and_then(|e| Ok((e.try_clone()?, e)))
1347             .context("failed to create target_reached Event pair: {}")?;
1348         self.target_reached_evt = Some(self_target_reached_evt);
1349 
1350         let state = self.state.clone();
1351 
1352         let command_tube = self.command_tube.take().unwrap();
1353 
1354         #[cfg(windows)]
1355         let vm_memory_client = self.vm_memory_client.take().unwrap();
1356         let release_memory_tube = self.release_memory_tube.take();
1357         #[cfg(feature = "registered_events")]
1358         let registered_evt_q = self.registered_evt_q.take();
1359         let pending_adjusted_response_event = self
1360             .pending_adjusted_response_event
1361             .try_clone()
1362             .context("failed to clone Event")?;
1363 
1364         self.worker_thread = Some(WorkerThread::start("v_balloon", move |kill_evt| {
1365             run_worker(
1366                 queues.inflate,
1367                 queues.deflate,
1368                 queues.stats,
1369                 queues.reporting,
1370                 queues.ws_data,
1371                 queues.ws_op,
1372                 command_tube,
1373                 #[cfg(windows)]
1374                 vm_memory_client,
1375                 release_memory_tube,
1376                 interrupt,
1377                 kill_evt,
1378                 target_reached_evt,
1379                 pending_adjusted_response_event,
1380                 mem,
1381                 state,
1382                 #[cfg(feature = "registered_events")]
1383                 registered_evt_q,
1384             )
1385         }));
1386 
1387         Ok(())
1388     }
1389 }
1390 
1391 impl VirtioDevice for Balloon {
keep_rds(&self) -> Vec<RawDescriptor>1392     fn keep_rds(&self) -> Vec<RawDescriptor> {
1393         let mut rds = Vec::new();
1394         if let Some(command_tube) = &self.command_tube {
1395             rds.push(command_tube.as_raw_descriptor());
1396         }
1397         if let Some(release_memory_tube) = &self.release_memory_tube {
1398             rds.push(release_memory_tube.as_raw_descriptor());
1399         }
1400         #[cfg(feature = "registered_events")]
1401         if let Some(registered_evt_q) = &self.registered_evt_q {
1402             rds.push(registered_evt_q.as_raw_descriptor());
1403         }
1404         rds.push(self.pending_adjusted_response_event.as_raw_descriptor());
1405         rds
1406     }
1407 
device_type(&self) -> DeviceType1408     fn device_type(&self) -> DeviceType {
1409         DeviceType::Balloon
1410     }
1411 
queue_max_sizes(&self) -> &[u16]1412     fn queue_max_sizes(&self) -> &[u16] {
1413         QUEUE_SIZES
1414     }
1415 
read_config(&self, offset: u64, data: &mut [u8])1416     fn read_config(&self, offset: u64, data: &mut [u8]) {
1417         copy_config(data, 0, self.get_config().as_bytes(), offset);
1418     }
1419 
write_config(&mut self, offset: u64, data: &[u8])1420     fn write_config(&mut self, offset: u64, data: &[u8]) {
1421         let mut config = self.get_config();
1422         copy_config(config.as_bytes_mut(), offset, data, 0);
1423         let mut state = block_on(self.state.lock());
1424         state.actual_pages = config.actual.to_native();
1425 
1426         // If balloon has updated to the requested memory, let the hypervisor know.
1427         if config.num_pages == config.actual {
1428             debug!(
1429                 "sending target reached event at {}",
1430                 u32::from(config.num_pages)
1431             );
1432             self.target_reached_evt.as_ref().map(|e| e.signal());
1433         }
1434         if state.failable_update && state.actual_pages == state.num_pages {
1435             state.failable_update = false;
1436             let num_pages = state.num_pages;
1437             state.pending_adjusted_responses.push_back(num_pages);
1438             let _ = self.pending_adjusted_response_event.signal();
1439         }
1440     }
1441 
features(&self) -> u641442     fn features(&self) -> u64 {
1443         self.features
1444     }
1445 
ack_features(&mut self, mut value: u64)1446     fn ack_features(&mut self, mut value: u64) {
1447         if value & !self.features != 0 {
1448             warn!("virtio_balloon got unknown feature ack {:x}", value);
1449             value &= self.features;
1450         }
1451         self.acked_features |= value;
1452     }
1453 
activate( &mut self, mem: GuestMemory, interrupt: Interrupt, queues: BTreeMap<usize, Queue>, ) -> anyhow::Result<()>1454     fn activate(
1455         &mut self,
1456         mem: GuestMemory,
1457         interrupt: Interrupt,
1458         queues: BTreeMap<usize, Queue>,
1459     ) -> anyhow::Result<()> {
1460         let queues = self.get_queues_from_map(queues)?;
1461         self.start_worker(mem, interrupt, queues)
1462     }
1463 
reset(&mut self) -> anyhow::Result<()>1464     fn reset(&mut self) -> anyhow::Result<()> {
1465         let _worker = self.stop_worker();
1466         Ok(())
1467     }
1468 
virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>>1469     fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
1470         match self.stop_worker() {
1471             StoppedWorker::WithQueues(paused_queues) => Ok(Some(paused_queues.into())),
1472             StoppedWorker::MissingQueues => {
1473                 anyhow::bail!("balloon queue workers did not stop cleanly.")
1474             }
1475             StoppedWorker::AlreadyStopped => {
1476                 // Device hasn't been activated.
1477                 Ok(None)
1478             }
1479         }
1480     }
1481 
virtio_wake( &mut self, queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>, ) -> anyhow::Result<()>1482     fn virtio_wake(
1483         &mut self,
1484         queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
1485     ) -> anyhow::Result<()> {
1486         if let Some((mem, interrupt, queues)) = queues_state {
1487             if queues.len() < 2 {
1488                 anyhow::bail!("{} queues were found, but an activated balloon must have at least 2 active queues.", queues.len());
1489             }
1490 
1491             let balloon_queues = self.get_queues_from_map(queues)?;
1492             self.start_worker(mem, interrupt, balloon_queues)?;
1493         }
1494         Ok(())
1495     }
1496 
virtio_snapshot(&mut self) -> anyhow::Result<serde_json::Value>1497     fn virtio_snapshot(&mut self) -> anyhow::Result<serde_json::Value> {
1498         let state = self
1499             .state
1500             .lock()
1501             .now_or_never()
1502             .context("failed to acquire balloon lock")?;
1503         serde_json::to_value(BalloonSnapshot {
1504             features: self.features,
1505             acked_features: self.acked_features,
1506             state: state.clone(),
1507             ws_num_bins: self.ws_num_bins,
1508         })
1509         .context("failed to serialize balloon state")
1510     }
1511 
virtio_restore(&mut self, data: serde_json::Value) -> anyhow::Result<()>1512     fn virtio_restore(&mut self, data: serde_json::Value) -> anyhow::Result<()> {
1513         let snap: BalloonSnapshot = serde_json::from_value(data).context("error deserializing")?;
1514         if snap.features != self.features {
1515             anyhow::bail!(
1516                 "balloon: expected features to match, but they did not. Live: {:?}, snapshot {:?}",
1517                 self.features,
1518                 snap.features,
1519             );
1520         }
1521 
1522         let mut state = self
1523             .state
1524             .lock()
1525             .now_or_never()
1526             .context("failed to acquire balloon lock")?;
1527         *state = snap.state;
1528         self.ws_num_bins = snap.ws_num_bins;
1529         self.acked_features = snap.acked_features;
1530         Ok(())
1531     }
1532 }
1533 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::suspendable_virtio_tests;
    use crate::virtio::descriptor_utils::create_descriptor_chain;
    use crate::virtio::descriptor_utils::DescriptorType;

    /// Checks that `handle_address_chain` parses the 32-bit PFNs in a readable descriptor.
    /// Each PFN should be turned into a guest address (`pfn << VIRTIO_BALLOON_PFN_SHIFT`) and
    /// handed to the closure.
    #[test]
    fn desc_parsing_inflate() {
        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
        // Two PFNs written back to back at 0x100. The single 8-byte readable descriptor
        // created below is expected to cover exactly these two values.
        memory
            .write_obj_at_addr(0x10u32, GuestAddress(0x100))
            .unwrap();
        memory
            .write_obj_at_addr(0xaa55aa55u32, GuestAddress(0x104))
            .unwrap();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(DescriptorType::Readable, 8)],
            0,
        )
        .expect("create_descriptor_chain failed");

        let mut addrs = Vec::new();
        // `expect` (rather than `assert!(res.is_ok())`) makes a failure report the actual
        // error instead of a bare "assertion failed".
        handle_address_chain(None, &mut chain, &mut |guest_address, len| {
            addrs.push((guest_address, len));
        })
        .expect("handle_address_chain failed");

        assert_eq!(addrs.len(), 2);
        assert_eq!(
            addrs[0].0,
            GuestAddress(0x10u64 << VIRTIO_BALLOON_PFN_SHIFT)
        );
        assert_eq!(
            addrs[1].0,
            GuestAddress(0xaa55aa55u64 << VIRTIO_BALLOON_PFN_SHIFT)
        );
    }

    /// Keeps the far ends of the control tubes alive for as long as the device under test
    /// exists.
    struct BalloonContext {
        _ctrl_tube: Tube,
        #[cfg(windows)]
        _mem_client_tube: Tube,
    }

    /// Mutates snapshot-able device state for `suspendable_virtio_tests!`.
    ///
    /// Inverting every bit of `ws_num_bins` guarantees the value differs from the original.
    /// Presumably the macro uses this to detect whether a restore actually overwrote the
    /// state; confirm against the macro definition.
    fn modify_device(_balloon_context: &mut BalloonContext, balloon: &mut Balloon) {
        balloon.ws_num_bins = !balloon.ws_num_bins;
    }

    /// Builds a balloon device plus the context that owns its peer tubes.
    fn create_device() -> (BalloonContext, Balloon) {
        let (_ctrl_tube, ctrl_tube_device) = Tube::pair().unwrap();
        #[cfg(windows)]
        let (_mem_client_tube, mem_client_tube_device) = Tube::pair().unwrap();
        (
            BalloonContext {
                _ctrl_tube,
                #[cfg(windows)]
                _mem_client_tube,
            },
            // NOTE(review): the positional arguments follow `Balloon::new`'s parameter order
            // (defined elsewhere in this file). Check that definition before changing any
            // of the literals below.
            Balloon::new(
                0,
                ctrl_tube_device,
                #[cfg(windows)]
                VmMemoryClient::new(mem_client_tube_device),
                None,
                1024,
                0,
                #[cfg(feature = "registered_events")]
                None,
                0,
            )
            .unwrap(),
        )
    }

    suspendable_virtio_tests!(balloon, create_device, 2, modify_device);
}
1617