1 // Copyright 2017 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 mod sys;
6
7 use std::collections::BTreeMap;
8 use std::collections::VecDeque;
9 use std::io::Write;
10 use std::sync::Arc;
11
12 use anyhow::anyhow;
13 use anyhow::Context;
14 use balloon_control::BalloonStats;
15 use balloon_control::BalloonTubeCommand;
16 use balloon_control::BalloonTubeResult;
17 use balloon_control::BalloonWS;
18 use balloon_control::WSBucket;
19 use balloon_control::VIRTIO_BALLOON_WS_MAX_NUM_BINS;
20 use balloon_control::VIRTIO_BALLOON_WS_MIN_NUM_BINS;
21 use base::debug;
22 use base::error;
23 use base::warn;
24 use base::AsRawDescriptor;
25 use base::Event;
26 use base::RawDescriptor;
27 #[cfg(feature = "registered_events")]
28 use base::SendTube;
29 use base::Tube;
30 use base::WorkerThread;
31 use cros_async::block_on;
32 use cros_async::sync::RwLock as AsyncRwLock;
33 use cros_async::AsyncTube;
34 use cros_async::EventAsync;
35 use cros_async::Executor;
36 #[cfg(feature = "registered_events")]
37 use cros_async::SendTubeAsync;
38 use data_model::Le16;
39 use data_model::Le32;
40 use data_model::Le64;
41 use futures::channel::mpsc;
42 use futures::channel::oneshot;
43 use futures::pin_mut;
44 use futures::select;
45 use futures::select_biased;
46 use futures::FutureExt;
47 use futures::StreamExt;
48 use remain::sorted;
49 use serde::Deserialize;
50 use serde::Serialize;
51 use thiserror::Error as ThisError;
52 #[cfg(windows)]
53 use vm_control::api::VmMemoryClient;
54 #[cfg(feature = "registered_events")]
55 use vm_control::RegisteredEventWithData;
56 use vm_memory::GuestAddress;
57 use vm_memory::GuestMemory;
58 use zerocopy::AsBytes;
59 use zerocopy::FromBytes;
60 use zerocopy::FromZeroes;
61
62 use super::async_utils;
63 use super::copy_config;
64 use super::create_stop_oneshot;
65 use super::DescriptorChain;
66 use super::DeviceType;
67 use super::Interrupt;
68 use super::Queue;
69 use super::Reader;
70 use super::StoppedWorker;
71 use super::VirtioDevice;
72 use crate::UnpinRequest;
73 use crate::UnpinResponse;
74
/// Errors returned by the balloon device's worker tasks.
#[sorted]
#[derive(ThisError, Debug)]
pub enum BalloonError {
    /// Failed an async await
    #[error("failed async await: {0}")]
    AsyncAwait(cros_async::AsyncError),
    /// Failed an async await (cause carried as an `anyhow::Error`).
    #[error("failed async await: {0}")]
    AsyncAwaitAnyhow(anyhow::Error),
    /// Failed to create event.
    #[error("failed to create event: {0}")]
    CreatingEvent(base::Error),
    /// Failed to create async message receiver.
    #[error("failed to create async message receiver: {0}")]
    CreatingMessageReceiver(base::TubeError),
    /// Failed to receive command message.
    #[error("failed to receive command message: {0}")]
    ReceivingCommand(base::TubeError),
    /// Failed to send command response.
    #[error("failed to send command response: {0}")]
    SendResponse(base::TubeError),
    /// Error while writing to virtqueue
    #[error("failed to write to virtqueue: {0}")]
    WriteQueue(std::io::Error),
    /// Failed to write config event.
    #[error("failed to write config event: {0}")]
    WritingConfigEvent(base::Error),
}
/// Result type used throughout the balloon worker, with [`BalloonError`] as the error.
pub type Result<T> = std::result::Result<T, BalloonError>;
104
// Balloon implements five virt IO queues: Inflate, Deflate, Stats, WsData, WsCmd.
// NOTE(review): the index constants below go up to WS_OP_VQ = 6, and the device can also run a
// page-reporting queue (the `reporting` field of BalloonQueues). Up to six queues can therefore be
// active at once, yet QUEUE_SIZES has only five entries. Confirm whether a sixth entry is needed
// when page reporting and WS reporting are both negotiated.
const QUEUE_SIZE: u16 = 128;
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE, QUEUE_SIZE, QUEUE_SIZE, QUEUE_SIZE, QUEUE_SIZE];

// Virtqueue indexes (virtio numbering; the `_`-prefixed free page queue index is unused here).
const INFLATEQ: usize = 0;
const DEFLATEQ: usize = 1;
const STATSQ: usize = 2;
const _FREE_PAGE_VQ: usize = 3;
const REPORTING_VQ: usize = 4;
const WS_DATA_VQ: usize = 5;
const WS_OP_VQ: usize = 6;

// Balloon page frame numbers address 4 KiB pages: address = pfn << VIRTIO_BALLOON_PFN_SHIFT.
const VIRTIO_BALLOON_PFN_SHIFT: u32 = 12;
const VIRTIO_BALLOON_PF_SIZE: u64 = 1 << VIRTIO_BALLOON_PFN_SHIFT;

// The feature bitmap for virtio balloon (values are bit numbers, not masks)
const VIRTIO_BALLOON_F_MUST_TELL_HOST: u32 = 0; // Tell before reclaiming pages
const VIRTIO_BALLOON_F_STATS_VQ: u32 = 1; // Stats reporting enabled
const VIRTIO_BALLOON_F_DEFLATE_ON_OOM: u32 = 2; // Deflate balloon on OOM
const VIRTIO_BALLOON_F_PAGE_REPORTING: u32 = 5; // Page reporting virtqueue
// TODO(b/273973298): this should maybe be bit 6? to be changed later
const VIRTIO_BALLOON_F_WS_REPORTING: u32 = 8; // Working Set Reporting virtqueues
128
#[derive(Copy, Clone)]
#[repr(u32)]
// Optional balloon features. Each variant's discriminant is the corresponding virtio feature
// bit number (VIRTIO_BALLOON_F_*).
pub enum BalloonFeatures {
    // Page Reporting enabled
    PageReporting = VIRTIO_BALLOON_F_PAGE_REPORTING,
    // WS Reporting enabled
    WSReporting = VIRTIO_BALLOON_F_WS_REPORTING,
}
138
// virtio_balloon_config is the balloon device configuration space defined by the virtio spec.
// It is `#[repr(C)]` and derives AsBytes/FromBytes so it can be read and written as raw bytes.
#[derive(Copy, Clone, Debug, Default, AsBytes, FromZeroes, FromBytes)]
#[repr(C)]
struct virtio_balloon_config {
    // Per the virtio spec: number of pages the host wants in the balloon.
    num_pages: Le32,
    // Per the virtio spec: number of pages the guest actually has in the balloon.
    actual: Le32,
    free_page_hint_cmd_id: Le32,
    poison_val: Le32,
    // WS field is part of proposed spec extension (b/273973298).
    ws_num_bins: u8,
    // Explicit padding so the struct has no implicit padding bytes.
    _reserved: [u8; 3],
}
151
// BalloonState is shared by the worker and device thread.
#[derive(Clone, Default, Serialize, Deserialize)]
struct BalloonState {
    // Target balloon size in pages; written by the Adjust command handler.
    num_pages: u32,
    // Current balloon size in pages; reported (converted to bytes) as `balloon_actual` in stats
    // and working set results, and compared with `num_pages` for failable adjustments.
    actual_pages: u32,
    // Set when a working set report has been requested from the guest. The next report on the
    // WS data queue is then forwarded on the command tube and this flag is cleared.
    expecting_ws: bool,
    // Flag indicating that the balloon is in the process of a failable update. This
    // is set by an Adjust command that has allow_failure set, and is cleared when the
    // Adjusted success/failure response is sent.
    failable_update: bool,
    // Page counts for which an `Adjusted` response still has to be sent on the command tube,
    // in FIFO order; drained by handle_pending_adjusted_responses.
    pending_adjusted_responses: VecDeque<u32>,
}
164
// The constants defining stats types (tags) in virtio_balloon_stat.
// As their names say, the two NONSTANDARD tags are not part of the standard tag set.
const VIRTIO_BALLOON_S_SWAP_IN: u16 = 0;
const VIRTIO_BALLOON_S_SWAP_OUT: u16 = 1;
const VIRTIO_BALLOON_S_MAJFLT: u16 = 2;
const VIRTIO_BALLOON_S_MINFLT: u16 = 3;
const VIRTIO_BALLOON_S_MEMFREE: u16 = 4;
const VIRTIO_BALLOON_S_MEMTOT: u16 = 5;
const VIRTIO_BALLOON_S_AVAIL: u16 = 6;
const VIRTIO_BALLOON_S_CACHES: u16 = 7;
const VIRTIO_BALLOON_S_HTLB_PGALLOC: u16 = 8;
const VIRTIO_BALLOON_S_HTLB_PGFAIL: u16 = 9;
const VIRTIO_BALLOON_S_NONSTANDARD_SHMEM: u16 = 65534;
const VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE: u16 = 65535;
178
// BalloonStat is used to deserialize stats from the stats_queue.
// `packed` makes each entry exactly 10 bytes on the wire: a 16-bit tag immediately followed by
// a 64-bit value, with no padding in between.
#[derive(Copy, Clone, FromZeroes, FromBytes, AsBytes)]
#[repr(C, packed)]
struct BalloonStat {
    // One of the VIRTIO_BALLOON_S_* tags.
    tag: Le16,
    val: Le64,
}
186
187 impl BalloonStat {
update_stats(&self, stats: &mut BalloonStats)188 fn update_stats(&self, stats: &mut BalloonStats) {
189 let val = Some(self.val.to_native());
190 match self.tag.to_native() {
191 VIRTIO_BALLOON_S_SWAP_IN => stats.swap_in = val,
192 VIRTIO_BALLOON_S_SWAP_OUT => stats.swap_out = val,
193 VIRTIO_BALLOON_S_MAJFLT => stats.major_faults = val,
194 VIRTIO_BALLOON_S_MINFLT => stats.minor_faults = val,
195 VIRTIO_BALLOON_S_MEMFREE => stats.free_memory = val,
196 VIRTIO_BALLOON_S_MEMTOT => stats.total_memory = val,
197 VIRTIO_BALLOON_S_AVAIL => stats.available_memory = val,
198 VIRTIO_BALLOON_S_CACHES => stats.disk_caches = val,
199 VIRTIO_BALLOON_S_HTLB_PGALLOC => stats.hugetlb_allocations = val,
200 VIRTIO_BALLOON_S_HTLB_PGFAIL => stats.hugetlb_failures = val,
201 VIRTIO_BALLOON_S_NONSTANDARD_SHMEM => stats.shared_memory = val,
202 VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE => stats.unevictable_memory = val,
203 _ => (),
204 }
205 }
206 }
207
// virtio_balloon_ws is used to deserialize from the ws data vq.
// Each entry describes one working set bucket. update_ws reads only `idle_age_ms` and
// `memory_size_bytes`; `tag` and `node_id` are not consumed there.
#[repr(C)]
#[derive(Copy, Clone, Debug, Default, AsBytes, FromZeroes, FromBytes)]
struct virtio_balloon_ws {
    tag: Le16,
    node_id: Le16,
    // virtio prefers field members to align on a word boundary so we must pad. see:
    // https://crsrc.org/o/src/third_party/kernel/v5.15/include/uapi/linux/virtio_balloon.h;l=105
    _reserved: [u8; 4],
    // Bucket idle age in milliseconds; becomes `WSBucket::age`.
    idle_age_ms: Le64,
    // TODO(b/273973298): these should become separate fields - bytes for ANON and FILE
    memory_size_bytes: [Le64; 2],
}
221
222 impl virtio_balloon_ws {
update_ws(&self, ws: &mut BalloonWS)223 fn update_ws(&self, ws: &mut BalloonWS) {
224 let bucket = WSBucket {
225 age: self.idle_age_ms.to_native(),
226 bytes: [
227 self.memory_size_bytes[0].to_native(),
228 self.memory_size_bytes[1].to_native(),
229 ],
230 };
231 ws.ws.push(bucket);
232 }
233 }
234
// Operation codes carried in `virtio_balloon_op::type_` on the WS op queue.
// The `_`-prefixed codes are defined but not used by this device.
const _VIRTIO_BALLOON_WS_OP_INVALID: u16 = 0;
const VIRTIO_BALLOON_WS_OP_REQUEST: u16 = 1;
const VIRTIO_BALLOON_WS_OP_CONFIG: u16 = 2;
const _VIRTIO_BALLOON_WS_OP_DISCARD: u16 = 3;
239
// virtio_balloon_op is used to serialize to the ws cmd vq.
// It holds only the operation code (a VIRTIO_BALLOON_WS_OP_* value); any operation payload is
// written into the same descriptor right after it.
#[repr(C, packed)]
#[derive(Copy, Clone, Debug, Default, AsBytes, FromZeroes, FromBytes)]
struct virtio_balloon_op {
    type_: Le16,
}
246
invoke_desc_handler<F>(ranges: Vec<(u64, u64)>, desc_handler: &mut F) where F: FnMut(GuestAddress, u64),247 fn invoke_desc_handler<F>(ranges: Vec<(u64, u64)>, desc_handler: &mut F)
248 where
249 F: FnMut(GuestAddress, u64),
250 {
251 for range in ranges {
252 desc_handler(GuestAddress(range.0), range.1);
253 }
254 }
255
256 // Release a list of guest memory ranges back to the host system.
257 // Unpin requests for each inflate range will be sent via `release_memory_tube`
258 // if provided, and then `desc_handler` will be called for each inflate range.
release_ranges<F>( release_memory_tube: Option<&Tube>, inflate_ranges: Vec<(u64, u64)>, desc_handler: &mut F, ) -> anyhow::Result<()> where F: FnMut(GuestAddress, u64),259 fn release_ranges<F>(
260 release_memory_tube: Option<&Tube>,
261 inflate_ranges: Vec<(u64, u64)>,
262 desc_handler: &mut F,
263 ) -> anyhow::Result<()>
264 where
265 F: FnMut(GuestAddress, u64),
266 {
267 if let Some(tube) = release_memory_tube {
268 let unpin_ranges = inflate_ranges
269 .iter()
270 .map(|v| {
271 (
272 v.0 >> VIRTIO_BALLOON_PFN_SHIFT,
273 v.1 / VIRTIO_BALLOON_PF_SIZE,
274 )
275 })
276 .collect();
277 let req = UnpinRequest {
278 ranges: unpin_ranges,
279 };
280 if let Err(e) = tube.send(&req) {
281 error!("failed to send unpin request: {}", e);
282 } else {
283 match tube.recv() {
284 Ok(resp) => match resp {
285 UnpinResponse::Success => invoke_desc_handler(inflate_ranges, desc_handler),
286 UnpinResponse::Failed => error!("failed to handle unpin request"),
287 },
288 Err(e) => error!("failed to handle get unpin response: {}", e),
289 }
290 }
291 } else {
292 invoke_desc_handler(inflate_ranges, desc_handler);
293 }
294
295 Ok(())
296 }
297
298 // Processes one message's list of addresses.
handle_address_chain<F>( release_memory_tube: Option<&Tube>, avail_desc: &mut DescriptorChain, desc_handler: &mut F, ) -> anyhow::Result<()> where F: FnMut(GuestAddress, u64),299 fn handle_address_chain<F>(
300 release_memory_tube: Option<&Tube>,
301 avail_desc: &mut DescriptorChain,
302 desc_handler: &mut F,
303 ) -> anyhow::Result<()>
304 where
305 F: FnMut(GuestAddress, u64),
306 {
307 // In a long-running system, there is no reason to expect that
308 // a significant number of freed pages are consecutive. However,
309 // batching is relatively simple and can result in significant
310 // gains in a newly booted system, so it's worth attempting.
311 let mut range_start = 0;
312 let mut range_size = 0;
313 let mut inflate_ranges: Vec<(u64, u64)> = Vec::new();
314 for res in avail_desc.reader.iter::<Le32>() {
315 let pfn = match res {
316 Ok(pfn) => pfn,
317 Err(e) => {
318 error!("error while reading unused pages: {}", e);
319 break;
320 }
321 };
322 let guest_address = (u64::from(pfn.to_native())) << VIRTIO_BALLOON_PFN_SHIFT;
323 if range_start + range_size == guest_address {
324 range_size += VIRTIO_BALLOON_PF_SIZE;
325 } else if range_start == guest_address + VIRTIO_BALLOON_PF_SIZE {
326 range_start = guest_address;
327 range_size += VIRTIO_BALLOON_PF_SIZE;
328 } else {
329 // Discontinuity, so flush the previous range. Note range_size
330 // will be 0 on the first iteration, so skip that.
331 if range_size != 0 {
332 inflate_ranges.push((range_start, range_size));
333 }
334 range_start = guest_address;
335 range_size = VIRTIO_BALLOON_PF_SIZE;
336 }
337 }
338 if range_size != 0 {
339 inflate_ranges.push((range_start, range_size));
340 }
341
342 release_ranges(release_memory_tube, inflate_ranges, desc_handler)
343 }
344
345 // Async task that handles the main balloon inflate and deflate queues.
handle_queue<F>( mut queue: Queue, mut queue_event: EventAsync, release_memory_tube: Option<&Tube>, mut desc_handler: F, mut stop_rx: oneshot::Receiver<()>, ) -> Queue where F: FnMut(GuestAddress, u64),346 async fn handle_queue<F>(
347 mut queue: Queue,
348 mut queue_event: EventAsync,
349 release_memory_tube: Option<&Tube>,
350 mut desc_handler: F,
351 mut stop_rx: oneshot::Receiver<()>,
352 ) -> Queue
353 where
354 F: FnMut(GuestAddress, u64),
355 {
356 loop {
357 let mut avail_desc = match queue
358 .next_async_interruptable(&mut queue_event, &mut stop_rx)
359 .await
360 {
361 Ok(Some(res)) => res,
362 Ok(None) => return queue,
363 Err(e) => {
364 error!("Failed to read descriptor {}", e);
365 return queue;
366 }
367 };
368 if let Err(e) =
369 handle_address_chain(release_memory_tube, &mut avail_desc, &mut desc_handler)
370 {
371 error!("balloon: failed to process inflate addresses: {}", e);
372 }
373 queue.add_used(avail_desc, 0);
374 queue.trigger_interrupt();
375 }
376 }
377
378 // Processes one page-reporting descriptor.
handle_reported_buffer<F>( release_memory_tube: Option<&Tube>, avail_desc: &DescriptorChain, desc_handler: &mut F, ) -> anyhow::Result<()> where F: FnMut(GuestAddress, u64),379 fn handle_reported_buffer<F>(
380 release_memory_tube: Option<&Tube>,
381 avail_desc: &DescriptorChain,
382 desc_handler: &mut F,
383 ) -> anyhow::Result<()>
384 where
385 F: FnMut(GuestAddress, u64),
386 {
387 let reported_ranges: Vec<(u64, u64)> = avail_desc
388 .reader
389 .get_remaining_regions()
390 .chain(avail_desc.writer.get_remaining_regions())
391 .map(|r| (r.offset, r.len as u64))
392 .collect();
393
394 release_ranges(release_memory_tube, reported_ranges, desc_handler)
395 }
396
397 // Async task that handles the page reporting queue.
handle_reporting_queue<F>( mut queue: Queue, mut queue_event: EventAsync, release_memory_tube: Option<&Tube>, mut desc_handler: F, mut stop_rx: oneshot::Receiver<()>, ) -> Queue where F: FnMut(GuestAddress, u64),398 async fn handle_reporting_queue<F>(
399 mut queue: Queue,
400 mut queue_event: EventAsync,
401 release_memory_tube: Option<&Tube>,
402 mut desc_handler: F,
403 mut stop_rx: oneshot::Receiver<()>,
404 ) -> Queue
405 where
406 F: FnMut(GuestAddress, u64),
407 {
408 loop {
409 let avail_desc = match queue
410 .next_async_interruptable(&mut queue_event, &mut stop_rx)
411 .await
412 {
413 Ok(Some(res)) => res,
414 Ok(None) => return queue,
415 Err(e) => {
416 error!("Failed to read descriptor {}", e);
417 return queue;
418 }
419 };
420 if let Err(e) = handle_reported_buffer(release_memory_tube, &avail_desc, &mut desc_handler)
421 {
422 error!("balloon: failed to process reported buffer: {}", e);
423 }
424 queue.add_used(avail_desc, 0);
425 queue.trigger_interrupt();
426 }
427 }
428
parse_balloon_stats(reader: &mut Reader) -> BalloonStats429 fn parse_balloon_stats(reader: &mut Reader) -> BalloonStats {
430 let mut stats: BalloonStats = Default::default();
431 for res in reader.iter::<BalloonStat>() {
432 match res {
433 Ok(stat) => stat.update_stats(&mut stats),
434 Err(e) => {
435 error!("error while reading stats: {}", e);
436 break;
437 }
438 };
439 }
440 stats
441 }
442
443 // Async task that handles the stats queue. Note that the cadence of this is driven by requests for
444 // balloon stats from the control pipe.
445 // The guests queues an initial buffer on boot, which is read and then this future will block until
446 // signaled from the command socket that stats should be collected again.
handle_stats_queue( mut queue: Queue, mut queue_event: EventAsync, mut stats_rx: mpsc::Receiver<()>, command_tube: &AsyncTube, #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>, state: Arc<AsyncRwLock<BalloonState>>, mut stop_rx: oneshot::Receiver<()>, ) -> Queue447 async fn handle_stats_queue(
448 mut queue: Queue,
449 mut queue_event: EventAsync,
450 mut stats_rx: mpsc::Receiver<()>,
451 command_tube: &AsyncTube,
452 #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>,
453 state: Arc<AsyncRwLock<BalloonState>>,
454 mut stop_rx: oneshot::Receiver<()>,
455 ) -> Queue {
456 let mut avail_desc = match queue
457 .next_async_interruptable(&mut queue_event, &mut stop_rx)
458 .await
459 {
460 // Consume the first stats buffer sent from the guest at startup. It was not
461 // requested by anyone, and the stats are stale.
462 Ok(Some(res)) => res,
463 Ok(None) => return queue,
464 Err(e) => {
465 error!("Failed to read descriptor {}", e);
466 return queue;
467 }
468 };
469
470 loop {
471 select_biased! {
472 msg = stats_rx.next() => {
473 // Wait for a request to read the stats.
474 match msg {
475 Some(()) => (),
476 None => {
477 error!("stats signal channel was closed");
478 return queue;
479 }
480 }
481 }
482 _ = stop_rx => return queue,
483 };
484
485 // Request a new stats_desc to the guest.
486 queue.add_used(avail_desc, 0);
487 queue.trigger_interrupt();
488
489 avail_desc = match queue.next_async(&mut queue_event).await {
490 Err(e) => {
491 error!("Failed to read descriptor {}", e);
492 return queue;
493 }
494 Ok(d) => d,
495 };
496 let stats = parse_balloon_stats(&mut avail_desc.reader);
497
498 let actual_pages = state.lock().await.actual_pages as u64;
499 let result = BalloonTubeResult::Stats {
500 balloon_actual: actual_pages << VIRTIO_BALLOON_PFN_SHIFT,
501 stats,
502 };
503 let send_result = command_tube.send(result).await;
504 if let Err(e) = send_result {
505 error!("failed to send stats result: {}", e);
506 }
507
508 #[cfg(feature = "registered_events")]
509 if let Some(registered_evt_q) = registered_evt_q {
510 if let Err(e) = registered_evt_q
511 .send(&RegisteredEventWithData::VirtioBalloonResize)
512 .await
513 {
514 error!("failed to send VirtioBalloonResize event: {}", e);
515 }
516 }
517 }
518 }
519
send_adjusted_response( tube: &AsyncTube, num_pages: u32, ) -> std::result::Result<(), base::TubeError>520 async fn send_adjusted_response(
521 tube: &AsyncTube,
522 num_pages: u32,
523 ) -> std::result::Result<(), base::TubeError> {
524 let num_bytes = (num_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
525 let result = BalloonTubeResult::Adjusted { num_bytes };
526 tube.send(result).await
527 }
528
// Working set operations forwarded from the command tube to handle_ws_op_queue, which writes
// them to the guest over the WS op queue.
enum WSOp {
    // Ask the guest for a working set report (written as VIRTIO_BALLOON_WS_OP_REQUEST).
    WSReport,
    // Send a working set configuration to the guest (written as VIRTIO_BALLOON_WS_OP_CONFIG,
    // followed by the bins and then the two thresholds).
    WSConfig {
        bins: Vec<u32>,
        refresh_threshold: u32,
        report_threshold: u32,
    },
}
537
handle_ws_op_queue( mut queue: Queue, mut queue_event: EventAsync, mut ws_op_rx: mpsc::Receiver<WSOp>, state: Arc<AsyncRwLock<BalloonState>>, mut stop_rx: oneshot::Receiver<()>, ) -> Result<Queue>538 async fn handle_ws_op_queue(
539 mut queue: Queue,
540 mut queue_event: EventAsync,
541 mut ws_op_rx: mpsc::Receiver<WSOp>,
542 state: Arc<AsyncRwLock<BalloonState>>,
543 mut stop_rx: oneshot::Receiver<()>,
544 ) -> Result<Queue> {
545 loop {
546 let op = select_biased! {
547 next_op = ws_op_rx.next().fuse() => {
548 match next_op {
549 Some(op) => op,
550 None => {
551 error!("ws op tube was closed");
552 break;
553 }
554 }
555 }
556 _ = stop_rx => {
557 break;
558 }
559 };
560 let mut avail_desc = queue
561 .next_async(&mut queue_event)
562 .await
563 .map_err(BalloonError::AsyncAwait)?;
564 let writer = &mut avail_desc.writer;
565
566 match op {
567 WSOp::WSReport => {
568 {
569 let mut state = state.lock().await;
570 state.expecting_ws = true;
571 }
572
573 let ws_r = virtio_balloon_op {
574 type_: VIRTIO_BALLOON_WS_OP_REQUEST.into(),
575 };
576
577 writer.write_obj(ws_r).map_err(BalloonError::WriteQueue)?;
578 }
579 WSOp::WSConfig {
580 bins,
581 refresh_threshold,
582 report_threshold,
583 } => {
584 let cmd = virtio_balloon_op {
585 type_: VIRTIO_BALLOON_WS_OP_CONFIG.into(),
586 };
587
588 writer.write_obj(cmd).map_err(BalloonError::WriteQueue)?;
589 writer
590 .write_all(bins.as_bytes())
591 .map_err(BalloonError::WriteQueue)?;
592 writer
593 .write_obj(refresh_threshold)
594 .map_err(BalloonError::WriteQueue)?;
595 writer
596 .write_obj(report_threshold)
597 .map_err(BalloonError::WriteQueue)?;
598 }
599 }
600
601 let len = writer.bytes_written() as u32;
602 queue.add_used(avail_desc, len);
603 queue.trigger_interrupt();
604 }
605
606 Ok(queue)
607 }
608
parse_balloon_ws(reader: &mut Reader) -> BalloonWS609 fn parse_balloon_ws(reader: &mut Reader) -> BalloonWS {
610 let mut ws = BalloonWS::new();
611 for res in reader.iter::<virtio_balloon_ws>() {
612 match res {
613 Ok(ws_msg) => {
614 ws_msg.update_ws(&mut ws);
615 }
616 Err(e) => {
617 error!("error while reading ws: {}", e);
618 break;
619 }
620 }
621 }
622 if ws.ws.len() < VIRTIO_BALLOON_WS_MIN_NUM_BINS || ws.ws.len() > VIRTIO_BALLOON_WS_MAX_NUM_BINS
623 {
624 error!("unexpected number of WS buckets: {}", ws.ws.len());
625 }
626 ws
627 }
628
629 // Async task that handles the stats queue. Note that the arrival of events on
630 // the WS vq may be the result of either a WS request (WS-R) command having
631 // been sent to the guest, or an unprompted send due to memory pressue in the
632 // guest. If the data was requested, we should also send that back on the
633 // command tube.
handle_ws_data_queue( mut queue: Queue, mut queue_event: EventAsync, command_tube: &AsyncTube, #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>, state: Arc<AsyncRwLock<BalloonState>>, mut stop_rx: oneshot::Receiver<()>, ) -> Result<Queue>634 async fn handle_ws_data_queue(
635 mut queue: Queue,
636 mut queue_event: EventAsync,
637 command_tube: &AsyncTube,
638 #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>,
639 state: Arc<AsyncRwLock<BalloonState>>,
640 mut stop_rx: oneshot::Receiver<()>,
641 ) -> Result<Queue> {
642 loop {
643 let mut avail_desc = match queue
644 .next_async_interruptable(&mut queue_event, &mut stop_rx)
645 .await
646 .map_err(BalloonError::AsyncAwait)?
647 {
648 Some(res) => res,
649 None => return Ok(queue),
650 };
651
652 let ws = parse_balloon_ws(&mut avail_desc.reader);
653
654 let mut state = state.lock().await;
655
656 // update ws report with balloon pages now that we have a lock on state
657 let balloon_actual = (state.actual_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
658
659 if state.expecting_ws {
660 let result = BalloonTubeResult::WorkingSet { ws, balloon_actual };
661 let send_result = command_tube.send(result).await;
662 if let Err(e) = send_result {
663 error!("failed to send ws result: {}", e);
664 }
665
666 state.expecting_ws = false;
667 } else {
668 #[cfg(feature = "registered_events")]
669 if let Some(registered_evt_q) = registered_evt_q {
670 if let Err(e) = registered_evt_q
671 .send(RegisteredEventWithData::from_ws(&ws, balloon_actual))
672 .await
673 {
674 error!("failed to send VirtioBalloonWSReport event: {}", e);
675 }
676 }
677 }
678
679 queue.add_used(avail_desc, 0);
680 queue.trigger_interrupt();
681 }
682 }
683
684 // Async task that handles the command socket. The command socket handles messages from the host
685 // requesting that the guest balloon be adjusted or to report guest memory statistics.
handle_command_tube( command_tube: &AsyncTube, interrupt: Interrupt, state: Arc<AsyncRwLock<BalloonState>>, mut stats_tx: mpsc::Sender<()>, mut ws_op_tx: mpsc::Sender<WSOp>, mut stop_rx: oneshot::Receiver<()>, ) -> Result<()>686 async fn handle_command_tube(
687 command_tube: &AsyncTube,
688 interrupt: Interrupt,
689 state: Arc<AsyncRwLock<BalloonState>>,
690 mut stats_tx: mpsc::Sender<()>,
691 mut ws_op_tx: mpsc::Sender<WSOp>,
692 mut stop_rx: oneshot::Receiver<()>,
693 ) -> Result<()> {
694 loop {
695 let cmd_res = select_biased! {
696 res = command_tube.next().fuse() => res,
697 _ = stop_rx => return Ok(())
698 };
699 match cmd_res {
700 Ok(command) => match command {
701 BalloonTubeCommand::Adjust {
702 num_bytes,
703 allow_failure,
704 } => {
705 let num_pages = (num_bytes >> VIRTIO_BALLOON_PFN_SHIFT) as u32;
706 let mut state = state.lock().await;
707
708 state.num_pages = num_pages;
709 interrupt.signal_config_changed();
710
711 if allow_failure {
712 if num_pages == state.actual_pages {
713 send_adjusted_response(command_tube, num_pages)
714 .await
715 .map_err(BalloonError::SendResponse)?;
716 } else {
717 state.failable_update = true;
718 }
719 }
720 }
721 BalloonTubeCommand::WorkingSetConfig {
722 bins,
723 refresh_threshold,
724 report_threshold,
725 } => {
726 if let Err(e) = ws_op_tx.try_send(WSOp::WSConfig {
727 bins,
728 refresh_threshold,
729 report_threshold,
730 }) {
731 error!("failed to send config to ws handler: {}", e);
732 }
733 }
734 BalloonTubeCommand::Stats => {
735 if let Err(e) = stats_tx.try_send(()) {
736 error!("failed to signal the stat handler: {}", e);
737 }
738 }
739 BalloonTubeCommand::WorkingSet => {
740 if let Err(e) = ws_op_tx.try_send(WSOp::WSReport) {
741 error!("failed to send report request to ws handler: {}", e);
742 }
743 }
744 },
745 #[cfg(windows)]
746 Err(base::TubeError::Recv(e)) if e.kind() == std::io::ErrorKind::TimedOut => {
747 // On Windows, async IO tasks like the next/recv above are cancelled as the VM is
748 // shutting down. For the sake of consistency with unix, we can't *just* return
749 // here; instead, we wait for the stop request to arrive, *and then* return.
750 //
751 // The real fix is to get rid of the global unblock pool, since then we won't
752 // cancel the tasks early (b/196911556).
753 let _ = stop_rx.await;
754 return Ok(());
755 }
756 Err(e) => {
757 return Err(BalloonError::ReceivingCommand(e));
758 }
759 }
760 }
761 }
762
handle_pending_adjusted_responses( pending_adjusted_response_event: EventAsync, command_tube: &AsyncTube, state: Arc<AsyncRwLock<BalloonState>>, ) -> Result<()>763 async fn handle_pending_adjusted_responses(
764 pending_adjusted_response_event: EventAsync,
765 command_tube: &AsyncTube,
766 state: Arc<AsyncRwLock<BalloonState>>,
767 ) -> Result<()> {
768 loop {
769 pending_adjusted_response_event
770 .next_val()
771 .await
772 .map_err(BalloonError::AsyncAwait)?;
773 while let Some(num_pages) = state.lock().await.pending_adjusted_responses.pop_front() {
774 send_adjusted_response(command_tube, num_pages)
775 .await
776 .map_err(BalloonError::SendResponse)?;
777 }
778 }
779 }
780
781 /// Represents queues & events for the balloon device.
782 struct BalloonQueues {
783 inflate: Queue,
784 deflate: Queue,
785 stats: Option<Queue>,
786 reporting: Option<Queue>,
787 ws_data: Option<Queue>,
788 ws_op: Option<Queue>,
789 }
790
791 impl BalloonQueues {
new(inflate: Queue, deflate: Queue) -> Self792 fn new(inflate: Queue, deflate: Queue) -> Self {
793 BalloonQueues {
794 inflate,
795 deflate,
796 stats: None,
797 reporting: None,
798 ws_data: None,
799 ws_op: None,
800 }
801 }
802 }
803
804 /// When the worker is stopped, the queues are preserved here.
805 struct PausedQueues {
806 inflate: Queue,
807 deflate: Queue,
808 stats: Option<Queue>,
809 reporting: Option<Queue>,
810 ws_data: Option<Queue>,
811 ws_op: Option<Queue>,
812 }
813
814 impl PausedQueues {
new(inflate: Queue, deflate: Queue) -> Self815 fn new(inflate: Queue, deflate: Queue) -> Self {
816 PausedQueues {
817 inflate,
818 deflate,
819 stats: None,
820 reporting: None,
821 ws_data: None,
822 ws_op: None,
823 }
824 }
825 }
826
apply_if_some<F, R>(queue_opt: Option<Queue>, mut func: F) where F: FnMut(Queue) -> R,827 fn apply_if_some<F, R>(queue_opt: Option<Queue>, mut func: F)
828 where
829 F: FnMut(Queue) -> R,
830 {
831 if let Some(queue) = queue_opt {
832 func(queue);
833 }
834 }
835
836 impl From<Box<PausedQueues>> for BTreeMap<usize, Queue> {
from(queues: Box<PausedQueues>) -> BTreeMap<usize, Queue>837 fn from(queues: Box<PausedQueues>) -> BTreeMap<usize, Queue> {
838 let mut ret = Vec::new();
839 ret.push(queues.inflate);
840 ret.push(queues.deflate);
841 apply_if_some(queues.stats, |stats| ret.push(stats));
842 apply_if_some(queues.reporting, |reporting| ret.push(reporting));
843 apply_if_some(queues.ws_data, |ws_data| ret.push(ws_data));
844 apply_if_some(queues.ws_op, |ws_op| ret.push(ws_op));
845 // WARNING: We don't use the indices from the virito spec on purpose, see comment in
846 // get_queues_from_map for the rationale.
847 ret.into_iter().enumerate().collect()
848 }
849 }
850
/// Stores data from the worker when it stops so that data can be re-used when
/// the worker is restarted.
///
/// Holds the tubes the worker owned while running and, when preserved, its paused queues.
struct WorkerReturn {
    release_memory_tube: Option<Tube>,
    command_tube: Tube,
    #[cfg(feature = "registered_events")]
    registered_evt_q: Option<SendTube>,
    paused_queues: Option<PausedQueues>,
    #[cfg(windows)]
    vm_memory_client: VmMemoryClient,
}
862
// The main worker thread. Initializes the asynchronous worker tasks and passes them to the
// executor to be processed.
run_worker( inflate_queue: Queue, deflate_queue: Queue, stats_queue: Option<Queue>, reporting_queue: Option<Queue>, ws_data_queue: Option<Queue>, ws_op_queue: Option<Queue>, command_tube: Tube, #[cfg(windows)] vm_memory_client: VmMemoryClient, release_memory_tube: Option<Tube>, interrupt: Interrupt, kill_evt: Event, target_reached_evt: Event, pending_adjusted_response_event: Event, mem: GuestMemory, state: Arc<AsyncRwLock<BalloonState>>, #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>, ) -> WorkerReturn865 fn run_worker(
866 inflate_queue: Queue,
867 deflate_queue: Queue,
868 stats_queue: Option<Queue>,
869 reporting_queue: Option<Queue>,
870 ws_data_queue: Option<Queue>,
871 ws_op_queue: Option<Queue>,
872 command_tube: Tube,
873 #[cfg(windows)] vm_memory_client: VmMemoryClient,
874 release_memory_tube: Option<Tube>,
875 interrupt: Interrupt,
876 kill_evt: Event,
877 target_reached_evt: Event,
878 pending_adjusted_response_event: Event,
879 mem: GuestMemory,
880 state: Arc<AsyncRwLock<BalloonState>>,
881 #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>,
882 ) -> WorkerReturn {
883 let ex = Executor::new().unwrap();
884 let command_tube = AsyncTube::new(&ex, command_tube).unwrap();
885 #[cfg(feature = "registered_events")]
886 let registered_evt_q_async = registered_evt_q
887 .as_ref()
888 .map(|q| SendTubeAsync::new(q.try_clone().unwrap(), &ex).unwrap());
889
890 let mut stop_queue_oneshots = Vec::new();
891
892 // We need a block to release all references to command_tube at the end before returning it.
893 let paused_queues = {
894 // The first queue is used for inflate messages
895 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
896 let inflate_queue_evt = inflate_queue
897 .event()
898 .try_clone()
899 .expect("failed to clone queue event");
900 let inflate = handle_queue(
901 inflate_queue,
902 EventAsync::new(inflate_queue_evt, &ex).expect("failed to create async event"),
903 release_memory_tube.as_ref(),
904 |guest_address, len| {
905 sys::free_memory(
906 &guest_address,
907 len,
908 #[cfg(windows)]
909 &vm_memory_client,
910 #[cfg(any(target_os = "android", target_os = "linux"))]
911 &mem,
912 )
913 },
914 stop_rx,
915 );
916 let inflate = inflate.fuse();
917 pin_mut!(inflate);
918
919 // The second queue is used for deflate messages
920 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
921 let deflate_queue_evt = deflate_queue
922 .event()
923 .try_clone()
924 .expect("failed to clone queue event");
925 let deflate = handle_queue(
926 deflate_queue,
927 EventAsync::new(deflate_queue_evt, &ex).expect("failed to create async event"),
928 None,
929 |guest_address, len| {
930 sys::reclaim_memory(
931 &guest_address,
932 len,
933 #[cfg(windows)]
934 &vm_memory_client,
935 )
936 },
937 stop_rx,
938 );
939 let deflate = deflate.fuse();
940 pin_mut!(deflate);
941
942 // The next queue is used for stats messages if VIRTIO_BALLOON_F_STATS_VQ is negotiated.
943 let (stats_tx, stats_rx) = mpsc::channel::<()>(1);
944 let has_stats_queue = stats_queue.is_some();
945 let stats = if let Some(stats_queue) = stats_queue {
946 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
947 let stats_queue_evt = stats_queue
948 .event()
949 .try_clone()
950 .expect("failed to clone queue event");
951 handle_stats_queue(
952 stats_queue,
953 EventAsync::new(stats_queue_evt, &ex).expect("failed to create async event"),
954 stats_rx,
955 &command_tube,
956 #[cfg(feature = "registered_events")]
957 registered_evt_q_async.as_ref(),
958 state.clone(),
959 stop_rx,
960 )
961 .left_future()
962 } else {
963 std::future::pending().right_future()
964 };
965 let stats = stats.fuse();
966 pin_mut!(stats);
967
968 // The next queue is used for reporting messages
969 let has_reporting_queue = reporting_queue.is_some();
970 let reporting = if let Some(reporting_queue) = reporting_queue {
971 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
972 let reporting_queue_evt = reporting_queue
973 .event()
974 .try_clone()
975 .expect("failed to clone queue event");
976 handle_reporting_queue(
977 reporting_queue,
978 EventAsync::new(reporting_queue_evt, &ex).expect("failed to create async event"),
979 release_memory_tube.as_ref(),
980 |guest_address, len| {
981 sys::free_memory(
982 &guest_address,
983 len,
984 #[cfg(windows)]
985 &vm_memory_client,
986 #[cfg(any(target_os = "android", target_os = "linux"))]
987 &mem,
988 )
989 },
990 stop_rx,
991 )
992 .left_future()
993 } else {
994 std::future::pending().right_future()
995 };
996 let reporting = reporting.fuse();
997 pin_mut!(reporting);
998
999 // If VIRTIO_BALLOON_F_WS_REPORTING is set 2 queues must handled - one for WS data and one
1000 // for WS notifications.
1001 let has_ws_data_queue = ws_data_queue.is_some();
1002 let ws_data = if let Some(ws_data_queue) = ws_data_queue {
1003 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1004 let ws_data_queue_evt = ws_data_queue
1005 .event()
1006 .try_clone()
1007 .expect("failed to clone queue event");
1008 handle_ws_data_queue(
1009 ws_data_queue,
1010 EventAsync::new(ws_data_queue_evt, &ex).expect("failed to create async event"),
1011 &command_tube,
1012 #[cfg(feature = "registered_events")]
1013 registered_evt_q_async.as_ref(),
1014 state.clone(),
1015 stop_rx,
1016 )
1017 .left_future()
1018 } else {
1019 std::future::pending().right_future()
1020 };
1021 let ws_data = ws_data.fuse();
1022 pin_mut!(ws_data);
1023
1024 let (ws_op_tx, ws_op_rx) = mpsc::channel::<WSOp>(1);
1025 let has_ws_op_queue = ws_op_queue.is_some();
1026 let ws_op = if let Some(ws_op_queue) = ws_op_queue {
1027 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1028 let ws_op_queue_evt = ws_op_queue
1029 .event()
1030 .try_clone()
1031 .expect("failed to clone queue event");
1032 handle_ws_op_queue(
1033 ws_op_queue,
1034 EventAsync::new(ws_op_queue_evt, &ex).expect("failed to create async event"),
1035 ws_op_rx,
1036 state.clone(),
1037 stop_rx,
1038 )
1039 .left_future()
1040 } else {
1041 std::future::pending().right_future()
1042 };
1043 let ws_op = ws_op.fuse();
1044 pin_mut!(ws_op);
1045
1046 // Future to handle command messages that resize the balloon.
1047 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1048 let command = handle_command_tube(
1049 &command_tube,
1050 interrupt.clone(),
1051 state.clone(),
1052 stats_tx,
1053 ws_op_tx,
1054 stop_rx,
1055 );
1056 pin_mut!(command);
1057
1058 // Process any requests to resample the irq value.
1059 let resample = async_utils::handle_irq_resample(&ex, interrupt.clone());
1060 pin_mut!(resample);
1061
1062 // Send a message if balloon target reached event is triggered.
1063 let target_reached = handle_target_reached(
1064 &ex,
1065 target_reached_evt,
1066 #[cfg(windows)]
1067 &vm_memory_client,
1068 );
1069 pin_mut!(target_reached);
1070
1071 // Exit if the kill event is triggered.
1072 let kill = async_utils::await_and_exit(&ex, kill_evt);
1073 pin_mut!(kill);
1074
1075 let pending_adjusted = handle_pending_adjusted_responses(
1076 EventAsync::new(pending_adjusted_response_event, &ex)
1077 .expect("failed to create async event"),
1078 &command_tube,
1079 state,
1080 );
1081 pin_mut!(pending_adjusted);
1082
1083 let res = ex.run_until(async {
1084 select! {
1085 _ = kill.fuse() => (),
1086 _ = inflate => return Err(anyhow!("inflate stopped unexpectedly")),
1087 _ = deflate => return Err(anyhow!("deflate stopped unexpectedly")),
1088 _ = stats => return Err(anyhow!("stats stopped unexpectedly")),
1089 _ = reporting => return Err(anyhow!("reporting stopped unexpectedly")),
1090 _ = command.fuse() => return Err(anyhow!("command stopped unexpectedly")),
1091 _ = ws_op => return Err(anyhow!("ws_op stopped unexpectedly")),
1092 _ = resample.fuse() => return Err(anyhow!("resample stopped unexpectedly")),
1093 _ = pending_adjusted.fuse() => return Err(anyhow!("pending_adjusted stopped unexpectedly")),
1094 _ = ws_data => return Err(anyhow!("ws_data stopped unexpectedly")),
1095 _ = target_reached.fuse() => return Err(anyhow!("target_reached stopped unexpectedly")),
1096 }
1097
1098 // Worker is shutting down. To recover the queues, we have to signal
1099 // all the queue futures to exit.
1100 for stop_tx in stop_queue_oneshots {
1101 if stop_tx.send(()).is_err() {
1102 return Err(anyhow!("failed to request stop for queue future"));
1103 }
1104 }
1105
1106 // Collect all the queues (awaiting any queue future should now
1107 // return its Queue immediately).
1108 let mut paused_queues = PausedQueues::new(
1109 inflate.await,
1110 deflate.await,
1111 );
1112 if has_reporting_queue {
1113 paused_queues.reporting = Some(reporting.await);
1114 }
1115 if has_stats_queue {
1116 paused_queues.stats = Some(stats.await);
1117 }
1118 if has_ws_data_queue {
1119 paused_queues.ws_data = Some(ws_data.await.context("failed to stop ws_data queue")?);
1120 }
1121 if has_ws_op_queue {
1122 paused_queues.ws_op = Some(ws_op.await.context("failed to stop ws_op queue")?);
1123 }
1124 Ok(paused_queues)
1125 });
1126
1127 match res {
1128 Err(e) => {
1129 error!("error happened in executor: {}", e);
1130 None
1131 }
1132 Ok(main_future_res) => match main_future_res {
1133 Ok(paused_queues) => Some(paused_queues),
1134 Err(e) => {
1135 error!("error happened in main balloon future: {}", e);
1136 None
1137 }
1138 },
1139 }
1140 };
1141
1142 WorkerReturn {
1143 command_tube: command_tube.into(),
1144 paused_queues,
1145 release_memory_tube,
1146 #[cfg(feature = "registered_events")]
1147 registered_evt_q,
1148 #[cfg(windows)]
1149 vm_memory_client,
1150 }
1151 }
1152
handle_target_reached( ex: &Executor, target_reached_evt: Event, #[cfg(windows)] vm_memory_client: &VmMemoryClient, ) -> anyhow::Result<()>1153 async fn handle_target_reached(
1154 ex: &Executor,
1155 target_reached_evt: Event,
1156 #[cfg(windows)] vm_memory_client: &VmMemoryClient,
1157 ) -> anyhow::Result<()> {
1158 let event_async =
1159 EventAsync::new(target_reached_evt, ex).context("failed to create EventAsync")?;
1160 loop {
1161 // Wait for target reached trigger.
1162 let _ = event_async.next_val().await;
1163 // Send the message to vm_control on the event. We don't have to read the current
1164 // size yet.
1165 sys::balloon_target_reached(
1166 0,
1167 #[cfg(windows)]
1168 vm_memory_client,
1169 );
1170 }
1171 // The above loop will never terminate and there is no reason to terminate it either. However,
1172 // the function is used in an executor that expects a Result<> return. Make sure that clippy
1173 // doesn't enforce the unreachable_code condition.
1174 #[allow(unreachable_code)]
1175 Ok(())
1176 }
1177
/// Virtio device for memory balloon inflation/deflation.
///
/// Resources the worker thread needs (tubes and, on Windows, the VM memory client) are kept in
/// `Option`s: `start_worker` moves them into the worker and `stop_worker` puts them back.
pub struct Balloon {
    /// Control tube handed to the worker's command handler; `None` while the worker owns it.
    command_tube: Option<Tube>,
    /// VM memory client (Windows only); `None` while the worker owns it.
    #[cfg(windows)]
    vm_memory_client: Option<VmMemoryClient>,
    /// Optional tube passed to the inflate and free-page-reporting queue handlers (used to ask
    /// CoIOMMU to unpin ranges before they are released); `None` when not configured or while
    /// the worker owns it.
    release_memory_tube: Option<Tube>,
    /// Signaled by `write_config` after it queues an entry in
    /// `state.pending_adjusted_responses`; a clone is handed to each worker.
    pending_adjusted_response_event: Event,
    /// Balloon state shared between the device and its worker thread.
    state: Arc<AsyncRwLock<BalloonState>>,
    /// Feature bits offered to the driver.
    features: u64,
    /// Feature bits acknowledged via `ack_features` (or restored from a snapshot).
    acked_features: u64,
    /// Running worker, if the device has been activated and not yet stopped.
    worker_thread: Option<WorkerThread<WorkerReturn>>,
    /// Optional registered-event queue; `None` when not configured or while the worker owns it.
    #[cfg(feature = "registered_events")]
    registered_evt_q: Option<SendTube>,
    /// Number of working-set bins reported in the config space.
    ws_num_bins: u8,
    /// Signaled from `write_config` when the driver reaches the target size; a fresh event pair
    /// is created each time a worker is started.
    target_reached_evt: Option<Event>,
}
1194
/// Snapshot of the [Balloon] state.
///
/// Produced by `virtio_snapshot` and consumed by `virtio_restore`; a restore is rejected when
/// `features` differs from the live device's offered features.
#[derive(Serialize, Deserialize)]
struct BalloonSnapshot {
    /// Shared balloon state (target/actual page counts, pending responses, flags).
    state: BalloonState,
    /// Feature bits the device offered when the snapshot was taken.
    features: u64,
    /// Feature bits the driver had acknowledged when the snapshot was taken.
    acked_features: u64,
    /// Number of working-set bins reported in the config space.
    ws_num_bins: u8,
}
1203
1204 impl Balloon {
1205 /// Creates a new virtio balloon device.
1206 /// To let Balloon able to successfully release the memory which are pinned
1207 /// by CoIOMMU to host, the release_memory_tube will be used to send the inflate
1208 /// ranges to CoIOMMU with UnpinRequest/UnpinResponse messages, so that The
1209 /// memory in the inflate range can be unpinned first.
new( base_features: u64, command_tube: Tube, #[cfg(windows)] vm_memory_client: VmMemoryClient, release_memory_tube: Option<Tube>, init_balloon_size: u64, enabled_features: u64, #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>, ws_num_bins: u8, ) -> Result<Balloon>1210 pub fn new(
1211 base_features: u64,
1212 command_tube: Tube,
1213 #[cfg(windows)] vm_memory_client: VmMemoryClient,
1214 release_memory_tube: Option<Tube>,
1215 init_balloon_size: u64,
1216 enabled_features: u64,
1217 #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>,
1218 ws_num_bins: u8,
1219 ) -> Result<Balloon> {
1220 let features = base_features
1221 | 1 << VIRTIO_BALLOON_F_MUST_TELL_HOST
1222 | 1 << VIRTIO_BALLOON_F_STATS_VQ
1223 | 1 << VIRTIO_BALLOON_F_DEFLATE_ON_OOM
1224 | enabled_features;
1225
1226 Ok(Balloon {
1227 command_tube: Some(command_tube),
1228 #[cfg(windows)]
1229 vm_memory_client: Some(vm_memory_client),
1230 release_memory_tube,
1231 pending_adjusted_response_event: Event::new().map_err(BalloonError::CreatingEvent)?,
1232 state: Arc::new(AsyncRwLock::new(BalloonState {
1233 num_pages: (init_balloon_size >> VIRTIO_BALLOON_PFN_SHIFT) as u32,
1234 actual_pages: 0,
1235 failable_update: false,
1236 pending_adjusted_responses: VecDeque::new(),
1237 expecting_ws: false,
1238 })),
1239 worker_thread: None,
1240 features,
1241 acked_features: 0,
1242 #[cfg(feature = "registered_events")]
1243 registered_evt_q,
1244 ws_num_bins,
1245 target_reached_evt: None,
1246 })
1247 }
1248
get_config(&self) -> virtio_balloon_config1249 fn get_config(&self) -> virtio_balloon_config {
1250 let state = block_on(self.state.lock());
1251 virtio_balloon_config {
1252 num_pages: state.num_pages.into(),
1253 actual: state.actual_pages.into(),
1254 // crosvm does not (currently) use free_page_hint_cmd_id or
1255 // poison_val, but they must be present in the right order and size
1256 // for the virtio-balloon driver in the guest to deserialize the
1257 // config correctly.
1258 free_page_hint_cmd_id: 0.into(),
1259 poison_val: 0.into(),
1260 ws_num_bins: self.ws_num_bins,
1261 _reserved: [0, 0, 0],
1262 }
1263 }
1264
stop_worker(&mut self) -> StoppedWorker<PausedQueues>1265 fn stop_worker(&mut self) -> StoppedWorker<PausedQueues> {
1266 if let Some(worker_thread) = self.worker_thread.take() {
1267 let worker_ret = worker_thread.stop();
1268 self.release_memory_tube = worker_ret.release_memory_tube;
1269 self.command_tube = Some(worker_ret.command_tube);
1270 #[cfg(feature = "registered_events")]
1271 {
1272 self.registered_evt_q = worker_ret.registered_evt_q;
1273 }
1274 #[cfg(windows)]
1275 {
1276 self.vm_memory_client = Some(worker_ret.vm_memory_client);
1277 }
1278
1279 if let Some(queues) = worker_ret.paused_queues {
1280 StoppedWorker::WithQueues(Box::new(queues))
1281 } else {
1282 StoppedWorker::MissingQueues
1283 }
1284 } else {
1285 StoppedWorker::AlreadyStopped
1286 }
1287 }
1288
1289 /// Given a filtered queue vector from [VirtioDevice::activate], extract
1290 /// the queues (accounting for queues that are missing because the features
1291 /// are not negotiated) into a structure that is easier to work with.
get_queues_from_map( &self, mut queues: BTreeMap<usize, Queue>, ) -> anyhow::Result<BalloonQueues>1292 fn get_queues_from_map(
1293 &self,
1294 mut queues: BTreeMap<usize, Queue>,
1295 ) -> anyhow::Result<BalloonQueues> {
1296 fn pop_queue(
1297 queues: &mut BTreeMap<usize, Queue>,
1298 expected_index: usize,
1299 name: &str,
1300 ) -> anyhow::Result<Queue> {
1301 let (queue_index, queue) = queues
1302 .pop_first()
1303 .with_context(|| format!("missing {}", name))?;
1304
1305 if queue_index == expected_index {
1306 debug!("{name} index {queue_index}");
1307 } else {
1308 warn!("expected {name} index {expected_index}, got {queue_index}");
1309 }
1310
1311 Ok(queue)
1312 }
1313
1314 // WARNING: We use `pop_first` instead of explicitly using the indices from the virtio spec
1315 // because the Linux virtio drivers only "allocates" queue indices that are used, so queues
1316 // need to be removed in order of ascending virtqueue index.
1317 let inflate_queue = pop_queue(&mut queues, INFLATEQ, "inflateq")?;
1318 let deflate_queue = pop_queue(&mut queues, DEFLATEQ, "deflateq")?;
1319 let mut queue_struct = BalloonQueues::new(inflate_queue, deflate_queue);
1320
1321 if self.acked_features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1322 queue_struct.stats = Some(pop_queue(&mut queues, STATSQ, "statsq")?);
1323 }
1324 if self.acked_features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1325 queue_struct.reporting = Some(pop_queue(&mut queues, REPORTING_VQ, "reporting_vq")?);
1326 }
1327 if self.acked_features & (1 << VIRTIO_BALLOON_F_WS_REPORTING) != 0 {
1328 queue_struct.ws_data = Some(pop_queue(&mut queues, WS_DATA_VQ, "ws_data_vq")?);
1329 queue_struct.ws_op = Some(pop_queue(&mut queues, WS_OP_VQ, "ws_op_vq")?);
1330 }
1331
1332 if !queues.is_empty() {
1333 return Err(anyhow!("unexpected queues {:?}", queues.into_keys()));
1334 }
1335
1336 Ok(queue_struct)
1337 }
1338
start_worker( &mut self, mem: GuestMemory, interrupt: Interrupt, queues: BalloonQueues, ) -> anyhow::Result<()>1339 fn start_worker(
1340 &mut self,
1341 mem: GuestMemory,
1342 interrupt: Interrupt,
1343 queues: BalloonQueues,
1344 ) -> anyhow::Result<()> {
1345 let (self_target_reached_evt, target_reached_evt) = Event::new()
1346 .and_then(|e| Ok((e.try_clone()?, e)))
1347 .context("failed to create target_reached Event pair: {}")?;
1348 self.target_reached_evt = Some(self_target_reached_evt);
1349
1350 let state = self.state.clone();
1351
1352 let command_tube = self.command_tube.take().unwrap();
1353
1354 #[cfg(windows)]
1355 let vm_memory_client = self.vm_memory_client.take().unwrap();
1356 let release_memory_tube = self.release_memory_tube.take();
1357 #[cfg(feature = "registered_events")]
1358 let registered_evt_q = self.registered_evt_q.take();
1359 let pending_adjusted_response_event = self
1360 .pending_adjusted_response_event
1361 .try_clone()
1362 .context("failed to clone Event")?;
1363
1364 self.worker_thread = Some(WorkerThread::start("v_balloon", move |kill_evt| {
1365 run_worker(
1366 queues.inflate,
1367 queues.deflate,
1368 queues.stats,
1369 queues.reporting,
1370 queues.ws_data,
1371 queues.ws_op,
1372 command_tube,
1373 #[cfg(windows)]
1374 vm_memory_client,
1375 release_memory_tube,
1376 interrupt,
1377 kill_evt,
1378 target_reached_evt,
1379 pending_adjusted_response_event,
1380 mem,
1381 state,
1382 #[cfg(feature = "registered_events")]
1383 registered_evt_q,
1384 )
1385 }));
1386
1387 Ok(())
1388 }
1389 }
1390
1391 impl VirtioDevice for Balloon {
keep_rds(&self) -> Vec<RawDescriptor>1392 fn keep_rds(&self) -> Vec<RawDescriptor> {
1393 let mut rds = Vec::new();
1394 if let Some(command_tube) = &self.command_tube {
1395 rds.push(command_tube.as_raw_descriptor());
1396 }
1397 if let Some(release_memory_tube) = &self.release_memory_tube {
1398 rds.push(release_memory_tube.as_raw_descriptor());
1399 }
1400 #[cfg(feature = "registered_events")]
1401 if let Some(registered_evt_q) = &self.registered_evt_q {
1402 rds.push(registered_evt_q.as_raw_descriptor());
1403 }
1404 rds.push(self.pending_adjusted_response_event.as_raw_descriptor());
1405 rds
1406 }
1407
device_type(&self) -> DeviceType1408 fn device_type(&self) -> DeviceType {
1409 DeviceType::Balloon
1410 }
1411
queue_max_sizes(&self) -> &[u16]1412 fn queue_max_sizes(&self) -> &[u16] {
1413 QUEUE_SIZES
1414 }
1415
read_config(&self, offset: u64, data: &mut [u8])1416 fn read_config(&self, offset: u64, data: &mut [u8]) {
1417 copy_config(data, 0, self.get_config().as_bytes(), offset);
1418 }
1419
write_config(&mut self, offset: u64, data: &[u8])1420 fn write_config(&mut self, offset: u64, data: &[u8]) {
1421 let mut config = self.get_config();
1422 copy_config(config.as_bytes_mut(), offset, data, 0);
1423 let mut state = block_on(self.state.lock());
1424 state.actual_pages = config.actual.to_native();
1425
1426 // If balloon has updated to the requested memory, let the hypervisor know.
1427 if config.num_pages == config.actual {
1428 debug!(
1429 "sending target reached event at {}",
1430 u32::from(config.num_pages)
1431 );
1432 self.target_reached_evt.as_ref().map(|e| e.signal());
1433 }
1434 if state.failable_update && state.actual_pages == state.num_pages {
1435 state.failable_update = false;
1436 let num_pages = state.num_pages;
1437 state.pending_adjusted_responses.push_back(num_pages);
1438 let _ = self.pending_adjusted_response_event.signal();
1439 }
1440 }
1441
features(&self) -> u641442 fn features(&self) -> u64 {
1443 self.features
1444 }
1445
ack_features(&mut self, mut value: u64)1446 fn ack_features(&mut self, mut value: u64) {
1447 if value & !self.features != 0 {
1448 warn!("virtio_balloon got unknown feature ack {:x}", value);
1449 value &= self.features;
1450 }
1451 self.acked_features |= value;
1452 }
1453
activate( &mut self, mem: GuestMemory, interrupt: Interrupt, queues: BTreeMap<usize, Queue>, ) -> anyhow::Result<()>1454 fn activate(
1455 &mut self,
1456 mem: GuestMemory,
1457 interrupt: Interrupt,
1458 queues: BTreeMap<usize, Queue>,
1459 ) -> anyhow::Result<()> {
1460 let queues = self.get_queues_from_map(queues)?;
1461 self.start_worker(mem, interrupt, queues)
1462 }
1463
reset(&mut self) -> anyhow::Result<()>1464 fn reset(&mut self) -> anyhow::Result<()> {
1465 let _worker = self.stop_worker();
1466 Ok(())
1467 }
1468
virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>>1469 fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
1470 match self.stop_worker() {
1471 StoppedWorker::WithQueues(paused_queues) => Ok(Some(paused_queues.into())),
1472 StoppedWorker::MissingQueues => {
1473 anyhow::bail!("balloon queue workers did not stop cleanly.")
1474 }
1475 StoppedWorker::AlreadyStopped => {
1476 // Device hasn't been activated.
1477 Ok(None)
1478 }
1479 }
1480 }
1481
virtio_wake( &mut self, queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>, ) -> anyhow::Result<()>1482 fn virtio_wake(
1483 &mut self,
1484 queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
1485 ) -> anyhow::Result<()> {
1486 if let Some((mem, interrupt, queues)) = queues_state {
1487 if queues.len() < 2 {
1488 anyhow::bail!("{} queues were found, but an activated balloon must have at least 2 active queues.", queues.len());
1489 }
1490
1491 let balloon_queues = self.get_queues_from_map(queues)?;
1492 self.start_worker(mem, interrupt, balloon_queues)?;
1493 }
1494 Ok(())
1495 }
1496
virtio_snapshot(&mut self) -> anyhow::Result<serde_json::Value>1497 fn virtio_snapshot(&mut self) -> anyhow::Result<serde_json::Value> {
1498 let state = self
1499 .state
1500 .lock()
1501 .now_or_never()
1502 .context("failed to acquire balloon lock")?;
1503 serde_json::to_value(BalloonSnapshot {
1504 features: self.features,
1505 acked_features: self.acked_features,
1506 state: state.clone(),
1507 ws_num_bins: self.ws_num_bins,
1508 })
1509 .context("failed to serialize balloon state")
1510 }
1511
virtio_restore(&mut self, data: serde_json::Value) -> anyhow::Result<()>1512 fn virtio_restore(&mut self, data: serde_json::Value) -> anyhow::Result<()> {
1513 let snap: BalloonSnapshot = serde_json::from_value(data).context("error deserializing")?;
1514 if snap.features != self.features {
1515 anyhow::bail!(
1516 "balloon: expected features to match, but they did not. Live: {:?}, snapshot {:?}",
1517 self.features,
1518 snap.features,
1519 );
1520 }
1521
1522 let mut state = self
1523 .state
1524 .lock()
1525 .now_or_never()
1526 .context("failed to acquire balloon lock")?;
1527 *state = snap.state;
1528 self.ws_num_bins = snap.ws_num_bins;
1529 self.acked_features = snap.acked_features;
1530 Ok(())
1531 }
1532 }
1533
#[cfg(test)]
mod tests {
    use super::*;
    use crate::suspendable_virtio_tests;
    use crate::virtio::descriptor_utils::create_descriptor_chain;
    use crate::virtio::descriptor_utils::DescriptorType;

    #[test]
    fn desc_parsing_inflate() {
        // Two 32-bit PFNs are stored back to back at 0x100. `handle_address_chain` must hand
        // each one to the callback, converted into a guest address.
        let pfns = [0x10u32, 0xaa55aa55u32];
        let memory = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
        for (slot, &pfn) in pfns.iter().enumerate() {
            let addr = GuestAddress(0x100 + 4 * slot as u64);
            memory.write_obj_at_addr(pfn, addr).unwrap();
        }

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(DescriptorType::Readable, 8)],
            0,
        )
        .expect("create_descriptor_chain failed");

        let mut seen = Vec::new();
        let result = handle_address_chain(None, &mut chain, &mut |guest_address, len| {
            seen.push((guest_address, len));
        });
        assert!(result.is_ok());
        assert_eq!(seen.len(), pfns.len());
        for ((guest_address, _len), pfn) in seen.iter().zip(pfns) {
            assert_eq!(
                *guest_address,
                GuestAddress(u64::from(pfn) << VIRTIO_BALLOON_PFN_SHIFT)
            );
        }
    }

    /// Keeps the host ends of the device's tubes alive for the duration of a test.
    struct BalloonContext {
        _ctrl_tube: Tube,
        #[cfg(windows)]
        _mem_client_tube: Tube,
    }

    /// Flips every bit of `ws_num_bins` so the snapshot differs from the original device.
    fn modify_device(_balloon_context: &mut BalloonContext, balloon: &mut Balloon) {
        balloon.ws_num_bins ^= u8::MAX;
    }

    /// Builds a minimal balloon device together with the host-side tube ends it talks to.
    fn create_device() -> (BalloonContext, Balloon) {
        let (_ctrl_tube, ctrl_tube_device) = Tube::pair().unwrap();
        #[cfg(windows)]
        let (_mem_client_tube, mem_client_tube_device) = Tube::pair().unwrap();

        let balloon = Balloon::new(
            0,
            ctrl_tube_device,
            #[cfg(windows)]
            VmMemoryClient::new(mem_client_tube_device),
            None,
            1024,
            0,
            #[cfg(feature = "registered_events")]
            None,
            0,
        )
        .unwrap();
        let context = BalloonContext {
            _ctrl_tube,
            #[cfg(windows)]
            _mem_client_tube,
        };
        (context, balloon)
    }

    suspendable_virtio_tests!(balloon, create_device, 2, modify_device);
}
1617