xref: /aosp_15_r20/external/crosvm/devices/src/virtio/snd/common_backend/mod.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 // virtio-sound spec: https://github.com/oasis-tcs/virtio-spec/blob/master/virtio-sound.tex
6 
7 use std::collections::BTreeMap;
8 use std::io;
9 use std::rc::Rc;
10 use std::sync::Arc;
11 
12 use anyhow::anyhow;
13 use anyhow::Context;
14 use audio_streams::BoxError;
15 use base::debug;
16 use base::error;
17 use base::warn;
18 use base::AsRawDescriptor;
19 use base::Descriptor;
20 use base::Error as SysError;
21 use base::Event;
22 use base::RawDescriptor;
23 use base::WorkerThread;
24 use cros_async::block_on;
25 use cros_async::sync::Condvar;
26 use cros_async::sync::RwLock as AsyncRwLock;
27 use cros_async::AsyncError;
28 use cros_async::EventAsync;
29 use cros_async::Executor;
30 use futures::channel::mpsc;
31 use futures::channel::oneshot;
32 use futures::channel::oneshot::Canceled;
33 use futures::future::FusedFuture;
34 use futures::join;
35 use futures::pin_mut;
36 use futures::select;
37 use futures::FutureExt;
38 use serde::Deserialize;
39 use serde::Serialize;
40 use thiserror::Error as ThisError;
41 use vm_memory::GuestMemory;
42 use zerocopy::AsBytes;
43 
44 use crate::virtio::async_utils;
45 use crate::virtio::copy_config;
46 use crate::virtio::device_constants::snd::virtio_snd_config;
47 use crate::virtio::snd::common_backend::async_funcs::*;
48 use crate::virtio::snd::common_backend::stream_info::StreamInfo;
49 use crate::virtio::snd::common_backend::stream_info::StreamInfoBuilder;
50 use crate::virtio::snd::common_backend::stream_info::StreamInfoSnapshot;
51 use crate::virtio::snd::constants::*;
52 use crate::virtio::snd::file_backend::create_file_stream_source_generators;
53 use crate::virtio::snd::file_backend::Error as FileError;
54 use crate::virtio::snd::layout::*;
55 use crate::virtio::snd::null_backend::create_null_stream_source_generators;
56 use crate::virtio::snd::parameters::Parameters;
57 use crate::virtio::snd::parameters::StreamSourceBackend;
58 use crate::virtio::snd::sys::create_stream_source_generators as sys_create_stream_source_generators;
59 use crate::virtio::snd::sys::set_audio_thread_priority;
60 use crate::virtio::snd::sys::SysAsyncStreamObjects;
61 use crate::virtio::snd::sys::SysAudioStreamSourceGenerator;
62 use crate::virtio::snd::sys::SysDirectionOutput;
63 use crate::virtio::DescriptorChain;
64 use crate::virtio::DeviceType;
65 use crate::virtio::Interrupt;
66 use crate::virtio::Queue;
67 use crate::virtio::VirtioDevice;
68 
69 pub mod async_funcs;
70 pub mod stream_info;
71 
/// Number of virtqueues used by the device, in order: control, event, tx and rx.
pub const MAX_QUEUE_NUM: usize = 4;
/// Maximum size advertised for each of the device's virtqueues.
pub const MAX_VRING_LEN: u16 = 1024;
75 
76 #[derive(ThisError, Debug)]
77 pub enum Error {
78     /// next_async failed.
79     #[error("Failed to read descriptor asynchronously: {0}")]
80     Async(AsyncError),
81     /// Creating stream failed.
82     #[error("Failed to create stream: {0}")]
83     CreateStream(BoxError),
84     /// Creating stream failed.
85     #[error("No stream source found.")]
86     EmptyStreamSource,
87     /// Creating kill event failed.
88     #[error("Failed to create kill event: {0}")]
89     CreateKillEvent(SysError),
90     /// Creating WaitContext failed.
91     #[error("Failed to create wait context: {0}")]
92     CreateWaitContext(SysError),
93     #[error("Failed to create file stream source generator")]
94     CreateFileStreamSourceGenerator(FileError),
95     /// Cloning kill event failed.
96     #[error("Failed to clone kill event: {0}")]
97     CloneKillEvent(SysError),
98     // Future error.
99     #[error("Unexpected error. Done was not triggered before dropped: {0}")]
100     DoneNotTriggered(Canceled),
101     /// Error reading message from queue.
102     #[error("Failed to read message: {0}")]
103     ReadMessage(io::Error),
104     /// Failed writing a response to a control message.
105     #[error("Failed to write message response: {0}")]
106     WriteResponse(io::Error),
107     // Mpsc read error.
108     #[error("Error in mpsc: {0}")]
109     MpscSend(futures::channel::mpsc::SendError),
110     // Oneshot send error.
111     #[error("Error in oneshot send")]
112     OneshotSend(()),
113     /// Stream not found.
114     #[error("stream id ({0}) < num_streams ({1})")]
115     StreamNotFound(usize, usize),
116     /// Fetch buffer error
117     #[error("Failed to get buffer from CRAS: {0}")]
118     FetchBuffer(BoxError),
119     /// Invalid buffer size
120     #[error("Invalid buffer size")]
121     InvalidBufferSize,
122     /// IoError
123     #[error("I/O failed: {0}")]
124     Io(io::Error),
125     /// Operation not supported.
126     #[error("Operation not supported")]
127     OperationNotSupported,
128     /// Writing to a buffer in the guest failed.
129     #[error("failed to write to buffer: {0}")]
130     WriteBuffer(io::Error),
131     // Invalid PCM worker state.
132     #[error("Invalid PCM worker state")]
133     InvalidPCMWorkerState,
134     // Invalid backend.
135     #[error("Backend is not implemented")]
136     InvalidBackend,
137     // Failed to generate StreamSource
138     #[error("Failed to generate stream source: {0}")]
139     GenerateStreamSource(BoxError),
140     // PCM worker unexpectedly quitted.
141     #[error("PCM worker quitted unexpectedly")]
142     PCMWorkerQuittedUnexpectedly,
143 }
144 
145 pub enum DirectionalStream {
146     Input(
147         usize, // `period_size` in `usize`
148         Box<dyn CaptureBufferReader>,
149     ),
150     Output(SysDirectionOutput),
151 }
152 
/// Run state of a PCM worker.
///
/// The discriminants are explicit so the state can be converted to an integer with `as`.
/// `Debug` is derived so the state can be logged and compared with `assert_eq!`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum WorkerStatus {
    /// The worker is paused.
    Pause = 0,
    /// The worker is running.
    Running = 1,
    /// The worker is quitting (or has quit).
    Quit = 2,
}
159 
160 // Stores constant data
161 #[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Debug)]
162 pub struct SndData {
163     pub(crate) jack_info: Vec<virtio_snd_jack_info>,
164     pub(crate) pcm_info: Vec<virtio_snd_pcm_info>,
165     pub(crate) chmap_info: Vec<virtio_snd_chmap_info>,
166 }
167 
168 impl SndData {
pcm_info_len(&self) -> usize169     pub fn pcm_info_len(&self) -> usize {
170         self.pcm_info.len()
171     }
172 
pcm_info_iter(&self) -> std::slice::Iter<'_, virtio_snd_pcm_info>173     pub fn pcm_info_iter(&self) -> std::slice::Iter<'_, virtio_snd_pcm_info> {
174         self.pcm_info.iter()
175     }
176 }
177 
178 const SUPPORTED_FORMATS: u64 = 1 << VIRTIO_SND_PCM_FMT_U8
179     | 1 << VIRTIO_SND_PCM_FMT_S16
180     | 1 << VIRTIO_SND_PCM_FMT_S24
181     | 1 << VIRTIO_SND_PCM_FMT_S32;
182 const SUPPORTED_FRAME_RATES: u64 = 1 << VIRTIO_SND_PCM_RATE_8000
183     | 1 << VIRTIO_SND_PCM_RATE_11025
184     | 1 << VIRTIO_SND_PCM_RATE_16000
185     | 1 << VIRTIO_SND_PCM_RATE_22050
186     | 1 << VIRTIO_SND_PCM_RATE_32000
187     | 1 << VIRTIO_SND_PCM_RATE_44100
188     | 1 << VIRTIO_SND_PCM_RATE_48000;
189 
190 // Response from pcm_worker to pcm_queue
191 pub struct PcmResponse {
192     pub(crate) desc_chain: DescriptorChain,
193     pub(crate) status: virtio_snd_pcm_status, // response to the pcm message
194     pub(crate) done: Option<oneshot::Sender<()>>, // when pcm response is written to the queue
195 }
196 
197 pub struct VirtioSnd {
198     cfg: virtio_snd_config,
199     snd_data: SndData,
200     stream_info_builders: Vec<StreamInfoBuilder>,
201     avail_features: u64,
202     acked_features: u64,
203     queue_sizes: Box<[u16]>,
204     worker_thread: Option<WorkerThread<Result<WorkerReturn, String>>>,
205     keep_rds: Vec<Descriptor>,
206     streams_state: Option<Vec<StreamInfoSnapshot>>,
207     card_index: usize,
208 }
209 
210 #[derive(Serialize, Deserialize)]
211 struct VirtioSndSnapshot {
212     avail_features: u64,
213     acked_features: u64,
214     queue_sizes: Vec<u16>,
215     streams_state: Option<Vec<StreamInfoSnapshot>>,
216     snd_data: SndData,
217 }
218 
219 impl VirtioSnd {
new(base_features: u64, params: Parameters) -> Result<VirtioSnd, Error>220     pub fn new(base_features: u64, params: Parameters) -> Result<VirtioSnd, Error> {
221         let params = resize_parameters_pcm_device_config(params);
222         let cfg = hardcoded_virtio_snd_config(&params);
223         let snd_data = hardcoded_snd_data(&params);
224         let avail_features = base_features;
225         let mut keep_rds: Vec<RawDescriptor> = Vec::new();
226 
227         let stream_info_builders =
228             create_stream_info_builders(&params, &snd_data, &mut keep_rds, params.card_index)?;
229 
230         Ok(VirtioSnd {
231             cfg,
232             snd_data,
233             stream_info_builders,
234             avail_features,
235             acked_features: 0,
236             queue_sizes: vec![MAX_VRING_LEN; MAX_QUEUE_NUM].into_boxed_slice(),
237             worker_thread: None,
238             keep_rds: keep_rds.iter().map(|rd| Descriptor(*rd)).collect(),
239             streams_state: None,
240             card_index: params.card_index,
241         })
242     }
243 }
244 
create_stream_source_generators( params: &Parameters, snd_data: &SndData, keep_rds: &mut Vec<RawDescriptor>, ) -> Result<Vec<SysAudioStreamSourceGenerator>, Error>245 fn create_stream_source_generators(
246     params: &Parameters,
247     snd_data: &SndData,
248     keep_rds: &mut Vec<RawDescriptor>,
249 ) -> Result<Vec<SysAudioStreamSourceGenerator>, Error> {
250     let generators = match params.backend {
251         StreamSourceBackend::NULL => create_null_stream_source_generators(snd_data),
252         StreamSourceBackend::FILE => {
253             create_file_stream_source_generators(params, snd_data, keep_rds)
254                 .map_err(Error::CreateFileStreamSourceGenerator)?
255         }
256         StreamSourceBackend::Sys(backend) => {
257             sys_create_stream_source_generators(backend, params, snd_data)
258         }
259     };
260     Ok(generators)
261 }
262 
263 /// Creates [`StreamInfoBuilder`]s by calling [`create_stream_source_generators()`] then zip
264 /// them with [`crate::virtio::snd::parameters::PCMDeviceParameters`] from the params to set
265 /// the parameters on each [`StreamInfoBuilder`] (e.g. effects).
create_stream_info_builders( params: &Parameters, snd_data: &SndData, keep_rds: &mut Vec<RawDescriptor>, card_index: usize, ) -> Result<Vec<StreamInfoBuilder>, Error>266 pub(crate) fn create_stream_info_builders(
267     params: &Parameters,
268     snd_data: &SndData,
269     keep_rds: &mut Vec<RawDescriptor>,
270     card_index: usize,
271 ) -> Result<Vec<StreamInfoBuilder>, Error> {
272     Ok(create_stream_source_generators(params, snd_data, keep_rds)?
273         .into_iter()
274         .map(Arc::new)
275         .zip(snd_data.pcm_info_iter())
276         .map(|(generator, pcm_info)| {
277             let device_params = params.get_device_params(pcm_info).unwrap_or_default();
278             StreamInfo::builder(generator, card_index)
279                 .effects(device_params.effects.unwrap_or_default())
280         })
281         .collect())
282 }
283 
284 // To be used with hardcoded_snd_data
hardcoded_virtio_snd_config(params: &Parameters) -> virtio_snd_config285 pub fn hardcoded_virtio_snd_config(params: &Parameters) -> virtio_snd_config {
286     virtio_snd_config {
287         jacks: 0.into(),
288         streams: params.get_total_streams().into(),
289         chmaps: (params.num_output_devices * 3 + params.num_input_devices).into(),
290     }
291 }
292 
293 // To be used with hardcoded_virtio_snd_config
hardcoded_snd_data(params: &Parameters) -> SndData294 pub fn hardcoded_snd_data(params: &Parameters) -> SndData {
295     let jack_info: Vec<virtio_snd_jack_info> = Vec::new();
296     let mut pcm_info: Vec<virtio_snd_pcm_info> = Vec::new();
297     let mut chmap_info: Vec<virtio_snd_chmap_info> = Vec::new();
298 
299     for dev in 0..params.num_output_devices {
300         for _ in 0..params.num_output_streams {
301             pcm_info.push(virtio_snd_pcm_info {
302                 hdr: virtio_snd_info {
303                     hda_fn_nid: dev.into(),
304                 },
305                 features: 0.into(), /* 1 << VIRTIO_SND_PCM_F_XXX */
306                 formats: SUPPORTED_FORMATS.into(),
307                 rates: SUPPORTED_FRAME_RATES.into(),
308                 direction: VIRTIO_SND_D_OUTPUT,
309                 channels_min: 1,
310                 channels_max: 6,
311                 padding: [0; 5],
312             });
313         }
314     }
315     for dev in 0..params.num_input_devices {
316         for _ in 0..params.num_input_streams {
317             pcm_info.push(virtio_snd_pcm_info {
318                 hdr: virtio_snd_info {
319                     hda_fn_nid: dev.into(),
320                 },
321                 features: 0.into(), /* 1 << VIRTIO_SND_PCM_F_XXX */
322                 formats: SUPPORTED_FORMATS.into(),
323                 rates: SUPPORTED_FRAME_RATES.into(),
324                 direction: VIRTIO_SND_D_INPUT,
325                 channels_min: 1,
326                 channels_max: 2,
327                 padding: [0; 5],
328             });
329         }
330     }
331     // Use stereo channel map.
332     let mut positions = [VIRTIO_SND_CHMAP_NONE; VIRTIO_SND_CHMAP_MAX_SIZE];
333     positions[0] = VIRTIO_SND_CHMAP_FL;
334     positions[1] = VIRTIO_SND_CHMAP_FR;
335     for dev in 0..params.num_output_devices {
336         chmap_info.push(virtio_snd_chmap_info {
337             hdr: virtio_snd_info {
338                 hda_fn_nid: dev.into(),
339             },
340             direction: VIRTIO_SND_D_OUTPUT,
341             channels: 2,
342             positions,
343         });
344     }
345     for dev in 0..params.num_input_devices {
346         chmap_info.push(virtio_snd_chmap_info {
347             hdr: virtio_snd_info {
348                 hda_fn_nid: dev.into(),
349             },
350             direction: VIRTIO_SND_D_INPUT,
351             channels: 2,
352             positions,
353         });
354     }
355     positions[2] = VIRTIO_SND_CHMAP_RL;
356     positions[3] = VIRTIO_SND_CHMAP_RR;
357     for dev in 0..params.num_output_devices {
358         chmap_info.push(virtio_snd_chmap_info {
359             hdr: virtio_snd_info {
360                 hda_fn_nid: dev.into(),
361             },
362             direction: VIRTIO_SND_D_OUTPUT,
363             channels: 4,
364             positions,
365         });
366     }
367     positions[2] = VIRTIO_SND_CHMAP_FC;
368     positions[3] = VIRTIO_SND_CHMAP_LFE;
369     positions[4] = VIRTIO_SND_CHMAP_RL;
370     positions[5] = VIRTIO_SND_CHMAP_RR;
371     for dev in 0..params.num_output_devices {
372         chmap_info.push(virtio_snd_chmap_info {
373             hdr: virtio_snd_info {
374                 hda_fn_nid: dev.into(),
375             },
376             direction: VIRTIO_SND_D_OUTPUT,
377             channels: 6,
378             positions,
379         });
380     }
381 
382     SndData {
383         jack_info,
384         pcm_info,
385         chmap_info,
386     }
387 }
388 
resize_parameters_pcm_device_config(mut params: Parameters) -> Parameters389 fn resize_parameters_pcm_device_config(mut params: Parameters) -> Parameters {
390     if params.output_device_config.len() > params.num_output_devices as usize {
391         warn!("Truncating output device config due to length > number of output devices");
392     }
393     params
394         .output_device_config
395         .resize_with(params.num_output_devices as usize, Default::default);
396 
397     if params.input_device_config.len() > params.num_input_devices as usize {
398         warn!("Truncating input device config due to length > number of input devices");
399     }
400     params
401         .input_device_config
402         .resize_with(params.num_input_devices as usize, Default::default);
403 
404     params
405 }
406 
407 impl VirtioDevice for VirtioSnd {
keep_rds(&self) -> Vec<RawDescriptor>408     fn keep_rds(&self) -> Vec<RawDescriptor> {
409         self.keep_rds
410             .iter()
411             .map(|descr| descr.as_raw_descriptor())
412             .collect()
413     }
414 
device_type(&self) -> DeviceType415     fn device_type(&self) -> DeviceType {
416         DeviceType::Sound
417     }
418 
queue_max_sizes(&self) -> &[u16]419     fn queue_max_sizes(&self) -> &[u16] {
420         &self.queue_sizes
421     }
422 
features(&self) -> u64423     fn features(&self) -> u64 {
424         self.avail_features
425     }
426 
ack_features(&mut self, mut v: u64)427     fn ack_features(&mut self, mut v: u64) {
428         // Check if the guest is ACK'ing a feature that we didn't claim to have.
429         let unrequested_features = v & !self.avail_features;
430         if unrequested_features != 0 {
431             warn!("virtio_fs got unknown feature ack: {:x}", v);
432 
433             // Don't count these features as acked.
434             v &= !unrequested_features;
435         }
436         self.acked_features |= v;
437     }
438 
read_config(&self, offset: u64, data: &mut [u8])439     fn read_config(&self, offset: u64, data: &mut [u8]) {
440         copy_config(data, 0, self.cfg.as_bytes(), offset)
441     }
442 
activate( &mut self, _guest_mem: GuestMemory, interrupt: Interrupt, queues: BTreeMap<usize, Queue>, ) -> anyhow::Result<()>443     fn activate(
444         &mut self,
445         _guest_mem: GuestMemory,
446         interrupt: Interrupt,
447         queues: BTreeMap<usize, Queue>,
448     ) -> anyhow::Result<()> {
449         if queues.len() != self.queue_sizes.len() {
450             return Err(anyhow!(
451                 "snd: expected {} queues, got {}",
452                 self.queue_sizes.len(),
453                 queues.len()
454             ));
455         }
456 
457         let snd_data = self.snd_data.clone();
458         let stream_info_builders = self.stream_info_builders.to_vec();
459         let streams_state = self.streams_state.take();
460         let card_index = self.card_index;
461         self.worker_thread = Some(WorkerThread::start("v_snd_common", move |kill_evt| {
462             let _thread_priority_handle = set_audio_thread_priority();
463             if let Err(e) = _thread_priority_handle {
464                 warn!("Failed to set audio thread to real time: {}", e);
465             };
466             run_worker(
467                 interrupt,
468                 queues,
469                 snd_data,
470                 kill_evt,
471                 stream_info_builders,
472                 streams_state,
473                 card_index,
474             )
475         }));
476 
477         Ok(())
478     }
479 
reset(&mut self) -> anyhow::Result<()>480     fn reset(&mut self) -> anyhow::Result<()> {
481         if let Some(worker_thread) = self.worker_thread.take() {
482             let _ = worker_thread.stop();
483         }
484 
485         Ok(())
486     }
487 
virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>>488     fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
489         if let Some(worker_thread) = self.worker_thread.take() {
490             let worker = worker_thread.stop().unwrap();
491             self.snd_data = worker.snd_data;
492             self.streams_state = Some(worker.streams_state);
493             return Ok(Some(BTreeMap::from_iter(
494                 worker.queues.into_iter().enumerate(),
495             )));
496         }
497         Ok(None)
498     }
499 
virtio_wake( &mut self, device_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>, ) -> anyhow::Result<()>500     fn virtio_wake(
501         &mut self,
502         device_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
503     ) -> anyhow::Result<()> {
504         match device_state {
505             None => Ok(()),
506             Some((mem, interrupt, queues)) => {
507                 // TODO: activate is just what we want at the moment, but we should probably move
508                 // it into a "start workers" function to make it obvious that it isn't strictly
509                 // used for activate events.
510                 self.activate(mem, interrupt, queues)?;
511                 Ok(())
512             }
513         }
514     }
515 
virtio_snapshot(&mut self) -> anyhow::Result<serde_json::Value>516     fn virtio_snapshot(&mut self) -> anyhow::Result<serde_json::Value> {
517         let streams_state = if let Some(states) = &self.streams_state {
518             let mut state_vec = Vec::new();
519             for state in states {
520                 state_vec.push(state.clone());
521             }
522             Some(state_vec)
523         } else {
524             None
525         };
526         serde_json::to_value(VirtioSndSnapshot {
527             avail_features: self.avail_features,
528             acked_features: self.acked_features,
529             queue_sizes: self.queue_sizes.to_vec(),
530             streams_state,
531             snd_data: self.snd_data.clone(),
532         })
533         .context("failed to Serialize Sound device")
534     }
535 
virtio_restore(&mut self, data: serde_json::Value) -> anyhow::Result<()>536     fn virtio_restore(&mut self, data: serde_json::Value) -> anyhow::Result<()> {
537         let mut deser: VirtioSndSnapshot =
538             serde_json::from_value(data).context("failed to Deserialize Sound device")?;
539         anyhow::ensure!(
540             deser.avail_features == self.avail_features,
541             "avail features doesn't match on restore: expected: {}, got: {}",
542             deser.avail_features,
543             self.avail_features
544         );
545         anyhow::ensure!(
546             deser.queue_sizes == self.queue_sizes.to_vec(),
547             "queue sizes doesn't match on restore: expected: {:?}, got: {:?}",
548             deser.queue_sizes,
549             self.queue_sizes.to_vec()
550         );
551         self.acked_features = deser.acked_features;
552         anyhow::ensure!(
553             deser.snd_data == self.snd_data,
554             "snd data doesn't match on restore: expected: {:?}, got: {:?}",
555             deser.snd_data,
556             self.snd_data
557         );
558         self.acked_features = deser.acked_features;
559         self.streams_state = deser.streams_state.take();
560         Ok(())
561     }
562 }
563 
/// Result of one round of queue workers, telling `run_worker` whether to go around again.
#[derive(PartialEq)]
enum LoopState {
    /// Reset the streams and run the workers again.
    Continue,
    /// Leave the worker loop.
    Break,
}
569 
run_worker( interrupt: Interrupt, queues: BTreeMap<usize, Queue>, snd_data: SndData, kill_evt: Event, stream_info_builders: Vec<StreamInfoBuilder>, streams_state: Option<Vec<StreamInfoSnapshot>>, card_index: usize, ) -> Result<WorkerReturn, String>570 fn run_worker(
571     interrupt: Interrupt,
572     queues: BTreeMap<usize, Queue>,
573     snd_data: SndData,
574     kill_evt: Event,
575     stream_info_builders: Vec<StreamInfoBuilder>,
576     streams_state: Option<Vec<StreamInfoSnapshot>>,
577     card_index: usize,
578 ) -> Result<WorkerReturn, String> {
579     let ex = Executor::new().expect("Failed to create an executor");
580 
581     if snd_data.pcm_info_len() != stream_info_builders.len() {
582         error!(
583             "snd: expected {} streams, got {}",
584             snd_data.pcm_info_len(),
585             stream_info_builders.len(),
586         );
587     }
588     let streams: Vec<AsyncRwLock<StreamInfo>> = stream_info_builders
589         .into_iter()
590         .map(StreamInfoBuilder::build)
591         .map(AsyncRwLock::new)
592         .collect();
593 
594     let (tx_send, mut tx_recv) = mpsc::unbounded();
595     let (rx_send, mut rx_recv) = mpsc::unbounded();
596     let tx_send_clone = tx_send.clone();
597     let rx_send_clone = rx_send.clone();
598     let restore_task = ex.spawn_local(async move {
599         if let Some(states) = &streams_state {
600             let ex = Executor::new().expect("Failed to create an executor");
601             for (stream, state) in streams.iter().zip(states.iter()) {
602                 stream.lock().await.restore(state);
603                 if state.state == VIRTIO_SND_R_PCM_START || state.state == VIRTIO_SND_R_PCM_PREPARE
604                 {
605                     stream
606                         .lock()
607                         .await
608                         .prepare(&ex, &tx_send_clone, &rx_send_clone)
609                         .await
610                         .expect("failed to prepare PCM");
611                 }
612                 if state.state == VIRTIO_SND_R_PCM_START {
613                     stream
614                         .lock()
615                         .await
616                         .start()
617                         .await
618                         .expect("failed to start PCM");
619                 }
620             }
621         }
622         streams
623     });
624     let streams = ex
625         .run_until(restore_task)
626         .expect("failed to restore streams");
627     let streams = Rc::new(AsyncRwLock::new(streams));
628 
629     let mut queues: Vec<(Queue, EventAsync)> = queues
630         .into_values()
631         .map(|q| {
632             let e = q.event().try_clone().expect("Failed to clone queue event");
633             (
634                 q,
635                 EventAsync::new(e, &ex).expect("Failed to create async event for queue"),
636             )
637         })
638         .collect();
639 
640     let (ctrl_queue, mut ctrl_queue_evt) = queues.remove(0);
641     let ctrl_queue = Rc::new(AsyncRwLock::new(ctrl_queue));
642     let (_event_queue, _event_queue_evt) = queues.remove(0);
643     let (tx_queue, tx_queue_evt) = queues.remove(0);
644     let (rx_queue, rx_queue_evt) = queues.remove(0);
645 
646     let tx_queue = Rc::new(AsyncRwLock::new(tx_queue));
647     let rx_queue = Rc::new(AsyncRwLock::new(rx_queue));
648 
649     let f_resample = async_utils::handle_irq_resample(&ex, interrupt.clone()).fuse();
650 
651     // Exit if the kill event is triggered.
652     let f_kill = async_utils::await_and_exit(&ex, kill_evt).fuse();
653 
654     pin_mut!(f_resample, f_kill);
655 
656     loop {
657         if run_worker_once(
658             &ex,
659             &streams,
660             &snd_data,
661             &mut f_kill,
662             &mut f_resample,
663             ctrl_queue.clone(),
664             &mut ctrl_queue_evt,
665             tx_queue.clone(),
666             &tx_queue_evt,
667             tx_send.clone(),
668             &mut tx_recv,
669             rx_queue.clone(),
670             &rx_queue_evt,
671             rx_send.clone(),
672             &mut rx_recv,
673             card_index,
674         ) == LoopState::Break
675         {
676             break;
677         }
678 
679         if let Err(e) = reset_streams(
680             &ex,
681             &streams,
682             &tx_queue,
683             &mut tx_recv,
684             &rx_queue,
685             &mut rx_recv,
686         ) {
687             error!("Error reset streams: {}", e);
688             break;
689         }
690     }
691     let streams_state_task = ex.spawn_local(async move {
692         let mut v = Vec::new();
693         for stream in streams.read_lock().await.iter() {
694             v.push(stream.read_lock().await.snapshot());
695         }
696         v
697     });
698     let streams_state = ex
699         .run_until(streams_state_task)
700         .expect("failed to save streams state");
701     let ctrl_queue = match Rc::try_unwrap(ctrl_queue) {
702         Ok(q) => q.into_inner(),
703         Err(_) => panic!("Too many refs to ctrl_queue"),
704     };
705     let tx_queue = match Rc::try_unwrap(tx_queue) {
706         Ok(q) => q.into_inner(),
707         Err(_) => panic!("Too many refs to tx_queue"),
708     };
709     let rx_queue = match Rc::try_unwrap(rx_queue) {
710         Ok(q) => q.into_inner(),
711         Err(_) => panic!("Too many refs to rx_queue"),
712     };
713     let queues = vec![ctrl_queue, _event_queue, tx_queue, rx_queue];
714 
715     Ok(WorkerReturn {
716         queues,
717         snd_data,
718         streams_state,
719     })
720 }
721 
722 struct WorkerReturn {
723     queues: Vec<Queue>,
724     snd_data: SndData,
725     streams_state: Vec<StreamInfoSnapshot>,
726 }
727 
notify_reset_signal(reset_signal: &(AsyncRwLock<bool>, Condvar))728 async fn notify_reset_signal(reset_signal: &(AsyncRwLock<bool>, Condvar)) {
729     let (lock, cvar) = reset_signal;
730     *lock.lock().await = true;
731     cvar.notify_all();
732 }
733 
734 /// Runs all workers once and exit if any worker exit.
735 ///
736 /// Returns [`LoopState::Break`] if the worker `f_kill` or `f_resample` exit, or something went
737 /// wrong on shutdown process. The caller should not run the worker again and should exit the main
738 /// loop.
739 ///
740 /// If this function returns [`LoopState::Continue`], the caller can continue the main loop by
741 /// resetting the streams and run the worker again.
run_worker_once( ex: &Executor, streams: &Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>, snd_data: &SndData, mut f_kill: &mut (impl FusedFuture<Output = anyhow::Result<()>> + Unpin), mut f_resample: &mut (impl FusedFuture<Output = anyhow::Result<()>> + Unpin), ctrl_queue: Rc<AsyncRwLock<Queue>>, ctrl_queue_evt: &mut EventAsync, tx_queue: Rc<AsyncRwLock<Queue>>, tx_queue_evt: &EventAsync, tx_send: mpsc::UnboundedSender<PcmResponse>, tx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>, rx_queue: Rc<AsyncRwLock<Queue>>, rx_queue_evt: &EventAsync, rx_send: mpsc::UnboundedSender<PcmResponse>, rx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>, card_index: usize, ) -> LoopState742 fn run_worker_once(
743     ex: &Executor,
744     streams: &Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>,
745     snd_data: &SndData,
746     mut f_kill: &mut (impl FusedFuture<Output = anyhow::Result<()>> + Unpin),
747     mut f_resample: &mut (impl FusedFuture<Output = anyhow::Result<()>> + Unpin),
748     ctrl_queue: Rc<AsyncRwLock<Queue>>,
749     ctrl_queue_evt: &mut EventAsync,
750     tx_queue: Rc<AsyncRwLock<Queue>>,
751     tx_queue_evt: &EventAsync,
752     tx_send: mpsc::UnboundedSender<PcmResponse>,
753     tx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
754     rx_queue: Rc<AsyncRwLock<Queue>>,
755     rx_queue_evt: &EventAsync,
756     rx_send: mpsc::UnboundedSender<PcmResponse>,
757     rx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
758     card_index: usize,
759 ) -> LoopState {
760     let tx_send2 = tx_send.clone();
761     let rx_send2 = rx_send.clone();
762 
763     let reset_signal = (AsyncRwLock::new(false), Condvar::new());
764 
765     let f_ctrl = handle_ctrl_queue(
766         ex,
767         streams,
768         snd_data,
769         ctrl_queue,
770         ctrl_queue_evt,
771         tx_send,
772         rx_send,
773         card_index,
774         Some(&reset_signal),
775     )
776     .fuse();
777 
778     // TODO(woodychow): Enable this when libcras sends jack connect/disconnect evts
779     // let f_event = handle_event_queue(
780     //     snd_state,
781     //     event_queue,
782     //     event_queue_evt,
783     //     interrupt,
784     // );
785     let f_tx = handle_pcm_queue(
786         streams,
787         tx_send2,
788         tx_queue.clone(),
789         tx_queue_evt,
790         card_index,
791         Some(&reset_signal),
792     )
793     .fuse();
794     let f_tx_response = send_pcm_response_worker(tx_queue, tx_recv, Some(&reset_signal)).fuse();
795     let f_rx = handle_pcm_queue(
796         streams,
797         rx_send2,
798         rx_queue.clone(),
799         rx_queue_evt,
800         card_index,
801         Some(&reset_signal),
802     )
803     .fuse();
804     let f_rx_response = send_pcm_response_worker(rx_queue, rx_recv, Some(&reset_signal)).fuse();
805 
806     pin_mut!(f_ctrl, f_tx, f_tx_response, f_rx, f_rx_response);
807 
808     let done = async {
809         select! {
810             res = f_ctrl => (res.context("error in handling ctrl queue"), LoopState::Continue),
811             res = f_tx => (res.context("error in handling tx queue"), LoopState::Continue),
812             res = f_tx_response => (res.context("error in handling tx response"), LoopState::Continue),
813             res = f_rx => (res.context("error in handling rx queue"), LoopState::Continue),
814             res = f_rx_response => (res.context("error in handling rx response"), LoopState::Continue),
815 
816             // For following workers, do not continue the loop
817             res = f_resample => (res.context("error in handle_irq_resample"), LoopState::Break),
818             res = f_kill => (res.context("error in await_and_exit"), LoopState::Break),
819         }
820     };
821 
822     match ex.run_until(done) {
823         Ok((res, loop_state)) => {
824             if let Err(e) = res {
825                 error!("Error in worker: {:#}", e);
826             }
827             if loop_state == LoopState::Break {
828                 return LoopState::Break;
829             }
830         }
831         Err(e) => {
832             error!("Error happened in executor: {}", e);
833         }
834     }
835 
836     warn!("Shutting down all workers for reset procedure");
837     block_on(notify_reset_signal(&reset_signal));
838 
839     let shutdown = async {
840         loop {
841             let (res, worker_name) = select!(
842                 res = f_ctrl => (res, "f_ctrl"),
843                 res = f_tx => (res, "f_tx"),
844                 res = f_tx_response => (res, "f_tx_response"),
845                 res = f_rx => (res, "f_rx"),
846                 res = f_rx_response => (res, "f_rx_response"),
847                 complete => break,
848             );
849             match res {
850                 Ok(_) => debug!("Worker {} stopped", worker_name),
851                 Err(e) => error!("Worker {} stopped with error {}", worker_name, e),
852             };
853         }
854     };
855 
856     if let Err(e) = ex.run_until(shutdown) {
857         error!("Error happened in executor while shutdown: {}", e);
858         return LoopState::Break;
859     }
860 
861     LoopState::Continue
862 }
863 
reset_streams( ex: &Executor, streams: &Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>, tx_queue: &Rc<AsyncRwLock<Queue>>, tx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>, rx_queue: &Rc<AsyncRwLock<Queue>>, rx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>, ) -> Result<(), AsyncError>864 fn reset_streams(
865     ex: &Executor,
866     streams: &Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>,
867     tx_queue: &Rc<AsyncRwLock<Queue>>,
868     tx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
869     rx_queue: &Rc<AsyncRwLock<Queue>>,
870     rx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
871 ) -> Result<(), AsyncError> {
872     let reset_signal = (AsyncRwLock::new(false), Condvar::new());
873 
874     let do_reset = async {
875         let streams = streams.read_lock().await;
876         for stream_info in &*streams {
877             let mut stream_info = stream_info.lock().await;
878             if stream_info.state == VIRTIO_SND_R_PCM_START {
879                 if let Err(e) = stream_info.stop().await {
880                     error!("Error on stop while resetting stream: {}", e);
881                 }
882             }
883             if stream_info.state == VIRTIO_SND_R_PCM_STOP
884                 || stream_info.state == VIRTIO_SND_R_PCM_PREPARE
885             {
886                 if let Err(e) = stream_info.release().await {
887                     error!("Error on release while resetting stream: {}", e);
888                 }
889             }
890             stream_info.just_reset = true;
891         }
892 
893         notify_reset_signal(&reset_signal).await;
894     };
895 
896     // Run these in a loop to ensure that they will survive until do_reset is finished
897     let f_tx_response = async {
898         while send_pcm_response_worker(tx_queue.clone(), tx_recv, Some(&reset_signal))
899             .await
900             .is_err()
901         {}
902     };
903 
904     let f_rx_response = async {
905         while send_pcm_response_worker(rx_queue.clone(), rx_recv, Some(&reset_signal))
906             .await
907             .is_err()
908         {}
909     };
910 
911     let reset = async {
912         join!(f_tx_response, f_rx_response, do_reset);
913     };
914 
915     ex.run_until(reset)
916 }
917 
#[cfg(test)]
// NOTE(review): presumably this allow exists because the `..PCMDeviceParameters::default()`
// struct updates below can be redundant on some build configurations. Confirm before removing.
#[allow(clippy::needless_update)]
mod tests {
    use audio_streams::StreamEffect;

    use super::*;
    use crate::virtio::snd::parameters::PCMDeviceParameters;

    /// Checks the features, config space, and generated PCM/chmap info of a `VirtioSnd` built
    /// with 3 output devices x 3 streams and 2 input devices x 2 streams.
    #[test]
    fn test_virtio_snd_new() {
        let params = Parameters {
            num_output_devices: 3,
            num_input_devices: 2,
            num_output_streams: 3,
            num_input_streams: 2,
            output_device_config: vec![PCMDeviceParameters {
                effects: Some(vec![StreamEffect::EchoCancellation]),
                ..PCMDeviceParameters::default()
            }],
            input_device_config: vec![PCMDeviceParameters {
                effects: Some(vec![StreamEffect::EchoCancellation]),
                ..PCMDeviceParameters::default()
            }],
            ..Default::default()
        };

        let res = VirtioSnd::new(123, params).unwrap();

        // Default values
        assert_eq!(res.snd_data.jack_info.len(), 0);
        assert_eq!(res.acked_features, 0);
        assert!(res.worker_thread.is_none());

        assert_eq!(res.avail_features, 123); // avail_features must be equal to the input
        assert_eq!(res.cfg.jacks.to_native(), 0);
        assert_eq!(res.cfg.streams.to_native(), 13); // (Output = 3*3) + (Input = 2*2)
        assert_eq!(res.cfg.chmaps.to_native(), 11); // (Output = 3*3) + (Input = 2*1)

        // Check snd_data.pcm_info: one entry per PCM stream.
        assert_eq!(res.snd_data.pcm_info.len(), 13);
        // Check hda_fn_nid (PCM device number that owns each stream).
        let expected_hda_fn_nid = [0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 1, 1];
        for (i, pcm_info) in res.snd_data.pcm_info.iter().enumerate() {
            assert_eq!(
                pcm_info.hdr.hda_fn_nid.to_native(),
                expected_hda_fn_nid[i],
                "pcm_info index {} incorrect hda_fn_nid",
                i
            );
        }
        // The first 9 PCM streams (3 output devices x 3 streams) must be OUTPUT.
        for (i, pcm_info) in res.snd_data.pcm_info.iter().enumerate().take(9) {
            assert_eq!(
                pcm_info.direction, VIRTIO_SND_D_OUTPUT,
                "pcm_info index {} incorrect direction",
                i
            );
        }
        // The remaining 4 PCM streams (2 input devices x 2 streams) must be INPUT.
        for (i, pcm_info) in res.snd_data.pcm_info.iter().enumerate().skip(9) {
            assert_eq!(
                pcm_info.direction, VIRTIO_SND_D_INPUT,
                "pcm_info index {} incorrect direction",
                i
            );
        }

        // Check snd_data.chmap_info
        assert_eq!(res.snd_data.chmap_info.len(), 11);
        let expected_hda_fn_nid = [0, 1, 2, 0, 1, 0, 1, 2, 0, 1, 2];
        // Check hda_fn_nid (PCM device number)
        for (i, chmap_info) in res.snd_data.chmap_info.iter().enumerate() {
            assert_eq!(
                chmap_info.hdr.hda_fn_nid.to_native(),
                expected_hda_fn_nid[i],
                "chmap_info index {} incorrect hda_fn_nid",
                i
            );
        }
    }

    /// `resize_parameters_pcm_device_config` truncates per-direction device configs that are
    /// longer than the configured number of devices.
    #[test]
    fn test_resize_parameters_pcm_device_config_truncate() {
        // If pcm_device_config is larger than number of devices, it will be truncated
        let params = Parameters {
            num_output_devices: 1,
            num_input_devices: 1,
            output_device_config: vec![PCMDeviceParameters::default(); 3],
            input_device_config: vec![PCMDeviceParameters::default(); 3],
            ..Parameters::default()
        };
        let params = resize_parameters_pcm_device_config(params);
        assert_eq!(params.output_device_config.len(), 1);
        assert_eq!(params.input_device_config.len(), 1);
    }

    /// `resize_parameters_pcm_device_config` pads per-direction device configs that are shorter
    /// than the configured number of devices with `PCMDeviceParameters::default()`. The
    /// caller-provided entries are kept first.
    #[test]
    fn test_resize_parameters_pcm_device_config_extend() {
        let params = Parameters {
            num_output_devices: 3,
            num_input_devices: 2,
            num_output_streams: 3,
            num_input_streams: 2,
            output_device_config: vec![PCMDeviceParameters {
                effects: Some(vec![StreamEffect::EchoCancellation]),
                ..PCMDeviceParameters::default()
            }],
            input_device_config: vec![PCMDeviceParameters {
                effects: Some(vec![StreamEffect::EchoCancellation]),
                ..PCMDeviceParameters::default()
            }],
            ..Default::default()
        };

        let params = resize_parameters_pcm_device_config(params);

        // Check output_device_config correctly extended
        assert_eq!(
            params.output_device_config,
            vec![
                PCMDeviceParameters {
                    // Keep from the parameters
                    effects: Some(vec![StreamEffect::EchoCancellation]),
                    ..PCMDeviceParameters::default()
                },
                PCMDeviceParameters::default(), // Extended with default
                PCMDeviceParameters::default(), // Extended with default
            ]
        );

        // Check input_device_config correctly extended
        assert_eq!(
            params.input_device_config,
            vec![
                PCMDeviceParameters {
                    // Keep from the parameters
                    effects: Some(vec![StreamEffect::EchoCancellation]),
                    ..PCMDeviceParameters::default()
                },
                PCMDeviceParameters::default(), // Extended with default
            ]
        );
    }
}
1062