xref: /aosp_15_r20/external/crosvm/common/audio_streams/src/shm_streams.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2019 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #[cfg(any(target_os = "android", target_os = "linux"))]
6 use std::os::unix::io::RawFd;
7 use std::sync::Arc;
8 use std::sync::Condvar;
9 use std::sync::Mutex;
10 use std::time::Duration;
11 use std::time::Instant;
12 
13 use remain::sorted;
14 use thiserror::Error;
15 
16 use crate::BoxError;
17 use crate::SampleFormat;
18 use crate::StreamDirection;
19 use crate::StreamEffect;
20 
21 type GenericResult<T> = std::result::Result<T, BoxError>;
22 
23 /// `BufferSet` is used as a callback mechanism for `ServerRequest` objects.
24 /// It is meant to be implemented by the audio stream, allowing arbitrary code
25 /// to be run after a buffer offset and length is set.
26 pub trait BufferSet {
27     /// Called when the client sets a buffer offset and length.
28     ///
29     /// `offset` is the offset within shared memory of the buffer and `frames`
30     /// indicates the number of audio frames that can be read from or written to
31     /// the buffer.
callback(&mut self, offset: usize, frames: usize) -> GenericResult<()>32     fn callback(&mut self, offset: usize, frames: usize) -> GenericResult<()>;
33 
34     /// Called when the client ignores a request from the server.
ignore(&mut self) -> GenericResult<()>35     fn ignore(&mut self) -> GenericResult<()>;
36 }
37 
38 #[sorted]
39 #[derive(Error, Debug)]
40 pub enum Error {
41     #[error("Provided number of frames {0} exceeds requested number of frames {1}")]
42     TooManyFrames(usize, usize),
43 }
44 
45 /// `ServerRequest` represents an active request from the server for the client
46 /// to provide a buffer in shared memory to playback from or capture to.
47 pub struct ServerRequest<'a> {
48     requested_frames: usize,
49     buffer_set: &'a mut dyn BufferSet,
50 }
51 
52 impl<'a> ServerRequest<'a> {
53     /// Create a new ServerRequest object
54     ///
55     /// Create a ServerRequest object representing a request from the server
56     /// for a buffer `requested_frames` in size.
57     ///
58     /// When the client responds to this request by calling
59     /// [`set_buffer_offset_and_frames`](ServerRequest::set_buffer_offset_and_frames),
60     /// BufferSet::callback will be called on `buffer_set`.
61     ///
62     /// # Arguments
63     /// * `requested_frames` - The requested buffer size in frames.
64     /// * `buffer_set` - The object implementing the callback for when a buffer is provided.
new<D: BufferSet>(requested_frames: usize, buffer_set: &'a mut D) -> Self65     pub fn new<D: BufferSet>(requested_frames: usize, buffer_set: &'a mut D) -> Self {
66         Self {
67             requested_frames,
68             buffer_set,
69         }
70     }
71 
72     /// Get the number of frames of audio data requested by the server.
73     ///
74     /// The returned value should never be greater than the `buffer_size`
75     /// given in [`new_stream`](ShmStreamSource::new_stream).
requested_frames(&self) -> usize76     pub fn requested_frames(&self) -> usize {
77         self.requested_frames
78     }
79 
80     /// Sets the buffer offset and length for the requested buffer.
81     ///
82     /// Sets the buffer offset and length of the buffer that fulfills this
83     /// server request to `offset` and `length`, respectively. This means that
84     /// `length` bytes of audio samples may be read from/written to that
85     /// location in `client_shm` for a playback/capture stream, respectively.
86     /// This function may only be called once for a `ServerRequest`, at which
87     /// point the ServerRequest is dropped and no further calls are possible.
88     ///
89     /// # Arguments
90     ///
91     /// * `offset` - The value to use as the new buffer offset for the next buffer.
92     /// * `frames` - The length of the next buffer in frames.
93     ///
94     /// # Errors
95     ///
96     /// * If `frames` is greater than `requested_frames`.
set_buffer_offset_and_frames(self, offset: usize, frames: usize) -> GenericResult<()>97     pub fn set_buffer_offset_and_frames(self, offset: usize, frames: usize) -> GenericResult<()> {
98         if frames > self.requested_frames {
99             return Err(Box::new(Error::TooManyFrames(
100                 frames,
101                 self.requested_frames,
102             )));
103         }
104 
105         self.buffer_set.callback(offset, frames)
106     }
107 
108     /// Ignore this request
109     ///
110     /// If the client does not intend to respond to this ServerRequest with a
111     /// buffer, they should call this function. The stream will be notified that
112     /// the request has been ignored and will handle it properly.
ignore_request(self) -> GenericResult<()>113     pub fn ignore_request(self) -> GenericResult<()> {
114         self.buffer_set.ignore()
115     }
116 }
117 
118 /// `ShmStream` allows a client to interact with an active CRAS stream.
119 pub trait ShmStream: Send {
120     /// Get the size of a frame of audio data for this stream.
frame_size(&self) -> usize121     fn frame_size(&self) -> usize;
122 
123     /// Get the number of channels of audio data for this stream.
num_channels(&self) -> usize124     fn num_channels(&self) -> usize;
125 
126     /// Get the frame rate of audio data for this stream.
frame_rate(&self) -> u32127     fn frame_rate(&self) -> u32;
128 
129     /// Waits until the next server message indicating action is required.
130     ///
131     /// For playback streams, this will be `AUDIO_MESSAGE_REQUEST_DATA`, meaning
132     /// that we must set the buffer offset to the next location where playback
133     /// data can be found.
134     /// For capture streams, this will be `AUDIO_MESSAGE_DATA_READY`, meaning
135     /// that we must set the buffer offset to the next location where captured
136     /// data can be written to.
137     /// Will return early if `timeout` elapses before a message is received.
138     ///
139     /// # Arguments
140     ///
141     /// * `timeout` - The amount of time to wait until a message is received.
142     ///
143     /// # Return value
144     ///
145     /// Returns `Some(request)` where `request` is an object that implements the
146     /// [`ServerRequest`] trait and which can be used to get the
147     /// number of bytes requested for playback streams or that have already been
148     /// written to shm for capture streams.
149     ///
150     /// If the timeout occurs before a message is received, returns `None`.
151     ///
152     /// # Errors
153     ///
154     /// * If an invalid message type is received for the stream.
wait_for_next_action_with_timeout( &mut self, timeout: Duration, ) -> GenericResult<Option<ServerRequest>>155     fn wait_for_next_action_with_timeout(
156         &mut self,
157         timeout: Duration,
158     ) -> GenericResult<Option<ServerRequest>>;
159 }
160 
/// Describes the operations a shared memory area must offer in order to be
/// handed to a `ShmStreamSource`.
pub trait SharedMemory {
    /// Error type produced when creating the shared memory fails.
    type Error: std::error::Error;

    /// Creates an unnamed shared memory region of `size` bytes.
    fn anon(size: u64) -> Result<Self, Self::Error>
    where
        Self: Sized;

    /// Returns the size of the shared memory in bytes.
    ///
    /// Changes made through other interfaces or by other users of the shared
    /// memory file descriptor are not reflected in the returned value.
    fn size(&self) -> u64;

    /// Returns the raw file descriptor backing the shared memory.
    #[cfg(any(target_os = "android", target_os = "linux"))]
    fn as_raw_fd(&self) -> RawFd;
}
180 
181 /// `ShmStreamSource` creates streams for playback or capture of audio.
182 pub trait ShmStreamSource<E: std::error::Error>: Send {
183     /// Creates a new [`ShmStream`]
184     ///
185     /// Creates a new `ShmStream` object, which allows:
186     /// * Waiting until the server has communicated that data is ready or requested that we make
187     ///   more data available.
188     /// * Setting the location and length of buffers for reading/writing audio data.
189     ///
190     /// # Arguments
191     ///
192     /// * `direction` - The direction of the stream, either `Playback` or `Capture`.
193     /// * `num_channels` - The number of audio channels for the stream.
194     /// * `format` - The audio format to use for audio samples.
195     /// * `frame_rate` - The stream's frame rate in Hz.
196     /// * `buffer_size` - The maximum size of an audio buffer. This will be the size used for
197     ///   transfers of audio data between client and server.
198     /// * `effects` - Audio effects to use for the stream, such as echo-cancellation.
199     /// * `client_shm` - The shared memory area that will contain samples.
200     /// * `buffer_offsets` - The two initial values to use as buffer offsets for streams. This way,
201     ///   the server will not write audio data to an arbitrary offset in `client_shm` if the client
202     ///   fails to update offsets in time.
203     ///
204     /// # Errors
205     ///
206     /// * If sending the connect stream message to the server fails.
207     #[allow(clippy::too_many_arguments)]
new_stream( &mut self, direction: StreamDirection, num_channels: usize, format: SampleFormat, frame_rate: u32, buffer_size: usize, effects: &[StreamEffect], client_shm: &dyn SharedMemory<Error = E>, buffer_offsets: [u64; 2], ) -> GenericResult<Box<dyn ShmStream>>208     fn new_stream(
209         &mut self,
210         direction: StreamDirection,
211         num_channels: usize,
212         format: SampleFormat,
213         frame_rate: u32,
214         buffer_size: usize,
215         effects: &[StreamEffect],
216         client_shm: &dyn SharedMemory<Error = E>,
217         buffer_offsets: [u64; 2],
218     ) -> GenericResult<Box<dyn ShmStream>>;
219 
220     /// Get a list of file descriptors used by the implementation.
221     ///
222     /// Returns any open file descriptors needed by the implementation.
223     /// This list helps users of the ShmStreamSource enter Linux jails without
224     /// closing needed file descriptors.
225     #[cfg(any(target_os = "android", target_os = "linux"))]
keep_fds(&self) -> Vec<RawFd>226     fn keep_fds(&self) -> Vec<RawFd> {
227         Vec::new()
228     }
229 }
230 
/// A [`ShmStream`] implementation that does nothing with the samples; it only
/// paces requests over time.
pub struct NullShmStream {
    /// Number of audio channels reported by the stream.
    num_channels: usize,
    /// Frame rate reported by the stream.
    frame_rate: u32,
    /// Number of frames asked for in every request.
    buffer_size: usize,
    /// Size of one frame in bytes (sample size times channel count).
    frame_size: usize,
    /// Time between two consecutive requests.
    interval: Duration,
    /// Offset from `start_time` at which the next request becomes due.
    next_frame: Duration,
    /// Moment the stream was created.
    start_time: Instant,
}
241 
242 impl NullShmStream {
243     /// Attempt to create a new NullShmStream with the given number of channels,
244     /// format, frame_rate, and buffer_size.
new( buffer_size: usize, num_channels: usize, format: SampleFormat, frame_rate: u32, ) -> Self245     pub fn new(
246         buffer_size: usize,
247         num_channels: usize,
248         format: SampleFormat,
249         frame_rate: u32,
250     ) -> Self {
251         let interval = Duration::from_millis(buffer_size as u64 * 1000 / frame_rate as u64);
252         Self {
253             num_channels,
254             frame_rate,
255             buffer_size,
256             frame_size: format.sample_bytes() * num_channels,
257             interval,
258             next_frame: interval,
259             start_time: Instant::now(),
260         }
261     }
262 }
263 
264 impl BufferSet for NullShmStream {
callback(&mut self, _offset: usize, _frames: usize) -> GenericResult<()>265     fn callback(&mut self, _offset: usize, _frames: usize) -> GenericResult<()> {
266         Ok(())
267     }
268 
ignore(&mut self) -> GenericResult<()>269     fn ignore(&mut self) -> GenericResult<()> {
270         Ok(())
271     }
272 }
273 
274 impl ShmStream for NullShmStream {
frame_size(&self) -> usize275     fn frame_size(&self) -> usize {
276         self.frame_size
277     }
278 
num_channels(&self) -> usize279     fn num_channels(&self) -> usize {
280         self.num_channels
281     }
282 
frame_rate(&self) -> u32283     fn frame_rate(&self) -> u32 {
284         self.frame_rate
285     }
286 
wait_for_next_action_with_timeout( &mut self, timeout: Duration, ) -> GenericResult<Option<ServerRequest>>287     fn wait_for_next_action_with_timeout(
288         &mut self,
289         timeout: Duration,
290     ) -> GenericResult<Option<ServerRequest>> {
291         let elapsed = self.start_time.elapsed();
292         if elapsed < self.next_frame {
293             if timeout < self.next_frame - elapsed {
294                 std::thread::sleep(timeout);
295                 return Ok(None);
296             } else {
297                 std::thread::sleep(self.next_frame - elapsed);
298             }
299         }
300         self.next_frame += self.interval;
301         Ok(Some(ServerRequest::new(self.buffer_size, self)))
302     }
303 }
304 
/// Factory that produces `NullShmStream` objects.
#[derive(Default)]
pub struct NullShmStreamSource;

impl NullShmStreamSource {
    /// Creates the (stateless) source.
    pub fn new() -> Self {
        Self
    }
}
314 
315 impl<E: std::error::Error> ShmStreamSource<E> for NullShmStreamSource {
new_stream( &mut self, _direction: StreamDirection, num_channels: usize, format: SampleFormat, frame_rate: u32, buffer_size: usize, _effects: &[StreamEffect], _client_shm: &dyn SharedMemory<Error = E>, _buffer_offsets: [u64; 2], ) -> GenericResult<Box<dyn ShmStream>>316     fn new_stream(
317         &mut self,
318         _direction: StreamDirection,
319         num_channels: usize,
320         format: SampleFormat,
321         frame_rate: u32,
322         buffer_size: usize,
323         _effects: &[StreamEffect],
324         _client_shm: &dyn SharedMemory<Error = E>,
325         _buffer_offsets: [u64; 2],
326     ) -> GenericResult<Box<dyn ShmStream>> {
327         let new_stream = NullShmStream::new(buffer_size, num_channels, format, frame_rate);
328         Ok(Box::new(new_stream))
329     }
330 }
331 
/// A [`ShmStream`] whose requests are triggered by hand.
///
/// Clones share the same `request_notifier`, so one clone can trigger a
/// request that another clone is waiting for.
#[derive(Clone)]
pub struct MockShmStream {
    /// Number of audio channels reported by the stream.
    num_channels: usize,
    /// Frame rate reported by the stream.
    frame_rate: u32,
    /// Number of frames asked for in every request.
    request_size: usize,
    /// Size of one frame in bytes (sample size times channel count).
    frame_size: usize,
    /// Flag that is `true` while a request is pending, paired with the
    /// condition variable used to signal changes to it.
    request_notifier: Arc<(Mutex<bool>, Condvar)>,
}
340 
341 impl MockShmStream {
342     /// Attempt to create a new MockShmStream with the given number of
343     /// channels, frame_rate, format, and buffer_size.
new( num_channels: usize, frame_rate: u32, format: SampleFormat, buffer_size: usize, ) -> Self344     pub fn new(
345         num_channels: usize,
346         frame_rate: u32,
347         format: SampleFormat,
348         buffer_size: usize,
349     ) -> Self {
350         #[allow(clippy::mutex_atomic)]
351         Self {
352             num_channels,
353             frame_rate,
354             request_size: buffer_size,
355             frame_size: format.sample_bytes() * num_channels,
356             request_notifier: Arc::new((Mutex::new(false), Condvar::new())),
357         }
358     }
359 
360     /// Call to request data from the stream, causing it to return from
361     /// `wait_for_next_action_with_timeout`. Will block until
362     /// `set_buffer_offset_and_frames` is called on the ServerRequest returned
363     /// from `wait_for_next_action_with_timeout`, or until `timeout` elapses.
364     /// Returns true if a response was successfully received.
trigger_callback_with_timeout(&mut self, timeout: Duration) -> bool365     pub fn trigger_callback_with_timeout(&mut self, timeout: Duration) -> bool {
366         let (lock, cvar) = &*self.request_notifier;
367         let mut requested = lock.lock().unwrap();
368         *requested = true;
369         cvar.notify_one();
370         let start_time = Instant::now();
371         while *requested {
372             requested = cvar.wait_timeout(requested, timeout).unwrap().0;
373             if start_time.elapsed() > timeout {
374                 // We failed to get a callback in time, mark this as false.
375                 *requested = false;
376                 return false;
377             }
378         }
379 
380         true
381     }
382 
notify_request(&mut self)383     fn notify_request(&mut self) {
384         let (lock, cvar) = &*self.request_notifier;
385         let mut requested = lock.lock().unwrap();
386         *requested = false;
387         cvar.notify_one();
388     }
389 }
390 
391 impl BufferSet for MockShmStream {
callback(&mut self, _offset: usize, _frames: usize) -> GenericResult<()>392     fn callback(&mut self, _offset: usize, _frames: usize) -> GenericResult<()> {
393         self.notify_request();
394         Ok(())
395     }
396 
ignore(&mut self) -> GenericResult<()>397     fn ignore(&mut self) -> GenericResult<()> {
398         self.notify_request();
399         Ok(())
400     }
401 }
402 
403 impl ShmStream for MockShmStream {
frame_size(&self) -> usize404     fn frame_size(&self) -> usize {
405         self.frame_size
406     }
407 
num_channels(&self) -> usize408     fn num_channels(&self) -> usize {
409         self.num_channels
410     }
411 
frame_rate(&self) -> u32412     fn frame_rate(&self) -> u32 {
413         self.frame_rate
414     }
415 
wait_for_next_action_with_timeout( &mut self, timeout: Duration, ) -> GenericResult<Option<ServerRequest>>416     fn wait_for_next_action_with_timeout(
417         &mut self,
418         timeout: Duration,
419     ) -> GenericResult<Option<ServerRequest>> {
420         {
421             let start_time = Instant::now();
422             let (lock, cvar) = &*self.request_notifier;
423             let mut requested = lock.lock().unwrap();
424             while !*requested {
425                 requested = cvar.wait_timeout(requested, timeout).unwrap().0;
426                 if start_time.elapsed() > timeout {
427                     return Ok(None);
428                 }
429             }
430         }
431 
432         Ok(Some(ServerRequest::new(self.request_size, self)))
433     }
434 }
435 
436 /// Source of `MockShmStream` objects.
437 #[derive(Clone, Default)]
438 pub struct MockShmStreamSource {
439     last_stream: Arc<(Mutex<Option<MockShmStream>>, Condvar)>,
440 }
441 
442 impl MockShmStreamSource {
new() -> Self443     pub fn new() -> Self {
444         Default::default()
445     }
446 
447     /// Get the last stream that has been created from this source. If no stream
448     /// has been created, block until one has.
get_last_stream(&self) -> MockShmStream449     pub fn get_last_stream(&self) -> MockShmStream {
450         let (last_stream, cvar) = &*self.last_stream;
451         let mut stream = last_stream.lock().unwrap();
452         loop {
453             match &*stream {
454                 None => stream = cvar.wait(stream).unwrap(),
455                 Some(ref s) => return s.clone(),
456             };
457         }
458     }
459 }
460 
461 impl<E: std::error::Error> ShmStreamSource<E> for MockShmStreamSource {
new_stream( &mut self, _direction: StreamDirection, num_channels: usize, format: SampleFormat, frame_rate: u32, buffer_size: usize, _effects: &[StreamEffect], _client_shm: &dyn SharedMemory<Error = E>, _buffer_offsets: [u64; 2], ) -> GenericResult<Box<dyn ShmStream>>462     fn new_stream(
463         &mut self,
464         _direction: StreamDirection,
465         num_channels: usize,
466         format: SampleFormat,
467         frame_rate: u32,
468         buffer_size: usize,
469         _effects: &[StreamEffect],
470         _client_shm: &dyn SharedMemory<Error = E>,
471         _buffer_offsets: [u64; 2],
472     ) -> GenericResult<Box<dyn ShmStream>> {
473         let (last_stream, cvar) = &*self.last_stream;
474         let mut stream = last_stream.lock().unwrap();
475 
476         let new_stream = MockShmStream::new(num_channels, frame_rate, format, buffer_size);
477         *stream = Some(new_stream.clone());
478         cvar.notify_one();
479         Ok(Box::new(new_stream))
480     }
481 }
482 
// Tests that run only for Unix, where `base::SharedMemory` is used.
#[cfg(all(test, unix))]
pub mod tests {
    use super::*;

    /// Minimal `SharedMemory` stand-in; the streams under test never touch it.
    struct MockSharedMemory {}

    impl SharedMemory for MockSharedMemory {
        type Error = super::Error;

        fn anon(_: u64) -> Result<Self, Self::Error> {
            Ok(MockSharedMemory {})
        }

        fn size(&self) -> u64 {
            0
        }

        #[cfg(any(target_os = "android", target_os = "linux"))]
        fn as_raw_fd(&self) -> RawFd {
            0
        }
    }

    /// A triggered request on the mock stream reaches the waiting thread and
    /// its response is reported back to the triggering side.
    #[test]
    fn mock_trigger_callback() {
        let stream_source = MockShmStreamSource::new();
        let mut worker_source = stream_source.clone();

        let buffer_size = 480;
        let num_channels = 2;
        let format = SampleFormat::S24LE;
        let shm = MockSharedMemory {};

        // The worker creates the stream, waits for a request and answers it.
        let worker = std::thread::spawn(move || {
            let mut stream = worker_source
                .new_stream(
                    StreamDirection::Playback,
                    num_channels,
                    format,
                    44100,
                    buffer_size,
                    &[],
                    &shm,
                    [400, 8000],
                )
                .expect("Failed to create stream");

            let pending = stream
                .wait_for_next_action_with_timeout(Duration::from_secs(5))
                .expect("Failed to wait for next action");
            match pending {
                None => 0,
                Some(req) => {
                    let frames = req.requested_frames();
                    req.set_buffer_offset_and_frames(872, frames)
                        .expect("Failed to set buffer offset and frames");
                    frames
                }
            }
        });

        let mut trigger = stream_source.get_last_stream();
        assert!(trigger.trigger_callback_with_timeout(Duration::from_secs(1)));

        let requested_frames = worker.join().expect("Failed to join thread");
        assert_eq!(requested_frames, buffer_size);
    }

    /// The null stream paces requests: the second request is not handed out
    /// before one buffer interval has passed.
    #[test]
    fn null_consumption_rate() {
        let frame_rate = 44100;
        let buffer_size = 480;
        let interval = Duration::from_millis(buffer_size as u64 * 1000 / frame_rate as u64);

        let shm = MockSharedMemory {};

        let began = Instant::now();

        let mut stream_source = NullShmStreamSource::new();
        let mut stream = stream_source
            .new_stream(
                StreamDirection::Playback,
                2,
                SampleFormat::S24LE,
                frame_rate,
                buffer_size,
                &[],
                &shm,
                [400, 8000],
            )
            .expect("Failed to create stream");

        let timeout = Duration::from_secs(5);
        let first = stream
            .wait_for_next_action_with_timeout(timeout)
            .expect("Failed to wait for first request")
            .expect("First request should not have timed out");
        first
            .set_buffer_offset_and_frames(276, 480)
            .expect("Failed to set buffer offset and length");

        // The second call should block until the first buffer is consumed.
        let _second = stream
            .wait_for_next_action_with_timeout(timeout)
            .expect("Failed to wait for second request");
        let elapsed = began.elapsed();
        assert!(
            elapsed > interval,
            "wait_for_next_action_with_timeout didn't block long enough: {:?}",
            elapsed
        );

        assert!(
            elapsed < timeout,
            "wait_for_next_action_with_timeout blocked for too long: {:?}",
            elapsed
        );
    }
}
603