xref: /aosp_15_r20/external/crosvm/win_audio/src/win_audio_impl/device_notification.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2023 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::sync::atomic::AtomicBool;
6 use std::sync::atomic::AtomicU32;
7 use std::sync::atomic::Ordering::SeqCst;
8 use std::sync::Arc;
9 
10 use base::info;
11 use libc::c_void;
12 use winapi::shared::guiddef::IsEqualGUID;
13 use winapi::shared::guiddef::REFIID;
14 use winapi::shared::minwindef::DWORD;
15 use winapi::shared::minwindef::ULONG;
16 use winapi::shared::winerror::E_INVALIDARG;
17 use winapi::shared::winerror::E_NOINTERFACE;
18 use winapi::shared::winerror::NOERROR;
19 use winapi::shared::wtypes::PROPERTYKEY;
20 use winapi::um::mmdeviceapi::EDataFlow;
21 use winapi::um::mmdeviceapi::ERole;
22 use winapi::um::mmdeviceapi::IMMNotificationClient;
23 use winapi::um::mmdeviceapi::IMMNotificationClientVtbl;
24 use winapi::um::objidlbase::IAgileObject;
25 use winapi::um::unknwnbase::IUnknown;
26 use winapi::um::unknwnbase::IUnknownVtbl;
27 use winapi::um::winnt::HRESULT;
28 use winapi::um::winnt::LPCWSTR;
29 use winapi::Interface;
30 use wio::com::ComPtr;
31 
/// This device notification client will be used to notify win_audio when a new audio device is
/// available. This notification client will only be registered when there are
/// no audio devices detected.
///
/// Instances are heap-allocated by `create_com_ptr` and handed out as `IMMNotificationClient`
/// COM pointers; the memory is reclaimed by `Release` when the reference count reaches 0.
#[repr(C)]
pub(crate) struct WinIMMNotificationClient {
    // Must remain the first field: with `#[repr(C)]` it sits at offset 0, which is what allows a
    // `*mut WinIMMNotificationClient` to be cast to `*mut IMMNotificationClient` / `*mut IUnknown`.
    pub lp_vtbl: &'static IMMNotificationClientVtbl,
    // COM reference count, adjusted by `AddRef`/`Release`.
    ref_count: AtomicU32,
    // Shared with `WinAudioRenderer`. This will be used in `next_playback_buffer` only when
    // `NoopStream` is being used. When this is set to `true`, `WinAudioRenderer` will attempt
    // to create a new `DeviceRenderer`.
    device_available: Arc<AtomicBool>,
    // Only default-device notifications whose data flow equals this value update
    // `device_available`.
    data_flow: EDataFlow,
}
45 
46 impl WinIMMNotificationClient {
47     /// The ComPtr is a `WinIMMNotificationClient` casted as an `IMMNotificationClient`.
create_com_ptr( device_available: Arc<AtomicBool>, data_flow: EDataFlow, ) -> ComPtr<IMMNotificationClient>48     pub(crate) fn create_com_ptr(
49         device_available: Arc<AtomicBool>,
50         data_flow: EDataFlow,
51     ) -> ComPtr<IMMNotificationClient> {
52         let win_imm_notification_client = Box::new(WinIMMNotificationClient {
53             lp_vtbl: IMM_NOTIFICATION_CLIENT_VTBL,
54             ref_count: AtomicU32::new(1),
55             device_available,
56             data_flow,
57         });
58 
59         // This is safe if the value passed into `from_raw` is structured in a way where it can
60         // match `IMMNotificationClient`. Since `win_imm_notification_client.cast_to_com_ptr()`
61         // does, this is safe.
62         //
63         // SAFETY: We are passing in a valid COM object that implements `IUnknown` into
64         // `from_raw`.
65         unsafe {
66             ComPtr::from_raw(
67                 Box::into_raw(win_imm_notification_client) as *mut IMMNotificationClient
68             )
69         }
70     }
71 
increment_counter(&self) -> ULONG72     fn increment_counter(&self) -> ULONG {
73         self.ref_count.fetch_add(1, SeqCst) + 1
74     }
75 
decrement_counter(&mut self) -> ULONG76     fn decrement_counter(&mut self) -> ULONG {
77         let old_val = self.ref_count.fetch_sub(1, SeqCst);
78         assert_ne!(
79             old_val, 0,
80             "Attempted to decrement WinIMMNotificationClient ref count when it \
81         is already 0."
82         );
83         old_val - 1
84     }
85 }
86 
87 impl Drop for WinIMMNotificationClient {
drop(&mut self)88     fn drop(&mut self) {
89         info!("IMMNotificationClient is dropped.");
90     }
91 }
92 
// TODO(b/274146821): Factor out common IUnknown code between here and `completion_handler.rs`.
/// Vtable shared by every `WinIMMNotificationClient`. The `IUnknown` entries implement reference
/// counting on `ref_count`; the notification entries point at the `on_*` callbacks in this file.
const IMM_NOTIFICATION_CLIENT_VTBL: &IMMNotificationClientVtbl = {
    &IMMNotificationClientVtbl {
        parent: IUnknownVtbl {
            QueryInterface: {
                /// Hands out `this` (with an added reference) for `IUnknown`,
                /// `IMMNotificationClient` and `IAgileObject`; for any other IID it leaves
                /// `*ppv_object` null and returns `E_NOINTERFACE`. Returns `E_INVALIDARG` if
                /// `ppv_object` is null.
                ///
                /// Safe because if `this` is not implemented (fails the RIID check) this function
                /// will just return. If it is valid, it should be able to safely increment the ref
                /// counter and set the pointer `ppv_object`.
                // NOTE(review): `riid` is dereferenced without a null check; callers must pass a
                // valid IID pointer.
                unsafe extern "system" fn query_interface(
                    this: *mut IUnknown,
                    riid: REFIID,
                    ppv_object: *mut *mut c_void,
                ) -> HRESULT {
                    info!("querying ref in IMMNotificationClient.");
                    if ppv_object.is_null() {
                        return E_INVALIDARG;
                    }

                    // Null the out pointer up front so it stays null on the failure path.
                    *ppv_object = std::ptr::null_mut();

                    // Check for valid RIID's
                    if IsEqualGUID(&*riid, &IUnknown::uuidof())
                        || IsEqualGUID(&*riid, &IMMNotificationClient::uuidof())
                        || IsEqualGUID(&*riid, &IAgileObject::uuidof())
                    {
                        // One pointer serves every supported interface because the vtable
                        // pointer (which begins with the `IUnknown` entries) is the first field.
                        *ppv_object = this as *mut c_void;
                        // The returned pointer carries its own reference.
                        (*this).AddRef();
                        return NOERROR;
                    }
                    E_NOINTERFACE
                }
                query_interface
            },
            AddRef: {
                /// Increments the reference count and returns the new count.
                ///
                /// Unsafe if `this` cannot be cast to `WinIMMNotificationClient`.
                ///
                /// This is safe because `this` is originally a `WinIMMNotificationClient`.
                unsafe extern "system" fn add_ref(this: *mut IUnknown) -> ULONG {
                    info!("Adding ref in IMMNotificationClient.");
                    let win_imm_notification_client = this as *mut WinIMMNotificationClient;
                    (*win_imm_notification_client).increment_counter()
                }
                add_ref
            },
            Release: {
                /// Decrements the reference count, frees the object when the count reaches 0, and
                /// returns the new count.
                ///
                /// Unsafe if `this` cannot be cast to `WinIMMNotificationClient`. Also
                /// would be unsafe if `release` is called more than `add_ref`.
                ///
                /// This is safe because `this` is
                /// originally a `WinIMMNotificationClient` and isn't called
                /// more than `add_ref`.
                unsafe extern "system" fn release(this: *mut IUnknown) -> ULONG {
                    info!("Releasing ref in IMMNotificationClient.");
                    // Decrementing will free the `this` pointer if its ref_count becomes 0.
                    let win_imm_notification_client = this as *mut WinIMMNotificationClient;
                    let ref_count = (*win_imm_notification_client).decrement_counter();
                    if ref_count == 0 {
                        // Last reference gone: rebuild the `Box` created in `create_com_ptr` so
                        // dropping it runs `Drop` and frees the allocation.
                        drop(Box::from_raw(this as *mut WinIMMNotificationClient));
                    }
                    ref_count
                }
                release
            },
        },
        OnDeviceStateChanged: on_device_state_change,
        OnDeviceAdded: on_device_added,
        OnDeviceRemoved: on_device_removed,
        OnDefaultDeviceChanged: on_default_device_changed,
        OnPropertyValueChanged: on_property_value_changed,
    }
};
165 
on_device_state_change( _this: *mut IMMNotificationClient, _pwstr_device_id: LPCWSTR, _dw_new_state: DWORD, ) -> HRESULT166 unsafe extern "system" fn on_device_state_change(
167     _this: *mut IMMNotificationClient,
168     _pwstr_device_id: LPCWSTR,
169     _dw_new_state: DWORD,
170 ) -> HRESULT {
171     info!("IMMNotificationClient: on_device_state_change called");
172     0
173 }
174 
175 /// Indicates that an audio enpoint device has been added. In practice, I have not seen this get
176 /// triggered, even if I add an audio device.
177 ///
178 /// # Safety
179 /// This is safe because this callback does nothing except for logging.
on_device_added( _this: *mut IMMNotificationClient, _pwstr_device_id: LPCWSTR, ) -> HRESULT180 unsafe extern "system" fn on_device_added(
181     _this: *mut IMMNotificationClient,
182     _pwstr_device_id: LPCWSTR,
183 ) -> HRESULT {
184     info!("IMMNotificationClient: on_device_added called");
185     0
186 }
187 
188 /// Indicates that an audio enpoint device has been removed. In practice, I have not seen this get
189 /// triggered, even if I unplug an audio device.
190 ///
191 /// # Safety
192 /// This is safe because this callback does nothing except for logging.
on_device_removed( _this: *mut IMMNotificationClient, _pwstr_device_id: LPCWSTR, ) -> HRESULT193 unsafe extern "system" fn on_device_removed(
194     _this: *mut IMMNotificationClient,
195     _pwstr_device_id: LPCWSTR,
196 ) -> HRESULT {
197     info!("IMMNotificationClient: on_device_removed called");
198     0
199 }
200 
201 /// Indicates that the default device has changed. In practice, this callback seemed reliable to
202 /// tell us when a new audio device has been added when no devices were previously present.
203 ///
204 /// # Safety
205 /// Safe because we know `IMMNotificationClient` was originally a `WinIMMNotificationClient`,
206 /// so we can cast safely.
on_default_device_changed( this: *mut IMMNotificationClient, flow: EDataFlow, _role: ERole, _pwstr_default_device_id: LPCWSTR, ) -> HRESULT207 unsafe extern "system" fn on_default_device_changed(
208     this: *mut IMMNotificationClient,
209     flow: EDataFlow,
210     _role: ERole,
211     _pwstr_default_device_id: LPCWSTR,
212 ) -> HRESULT {
213     info!("IMMNotificationClient: on_default_device_changed called");
214     let win = &*(this as *mut WinIMMNotificationClient);
215     if flow == win.data_flow {
216         base::info!("New device found");
217         win.device_available.store(true, SeqCst);
218     }
219     0
220 }
221 
222 /// Indicates that a property in an audio endpoint device has changed. In practice, this callback
223 /// gets spammed a lot and the information provided isn't useful.
224 ///
225 /// # Safety
226 /// This is safe because this callback does nothing.
on_property_value_changed( _this: *mut IMMNotificationClient, _pwstr_device_id: LPCWSTR, _key: PROPERTYKEY, ) -> HRESULT227 unsafe extern "system" fn on_property_value_changed(
228     _this: *mut IMMNotificationClient,
229     _pwstr_device_id: LPCWSTR,
230     _key: PROPERTYKEY,
231 ) -> HRESULT {
232     0
233 }
234 
/// The following tests the correctness of the COM object implementation. It won't test for
/// notifications of new devices coming from Windows; notification callbacks are invoked directly.
#[cfg(test)]
mod test {
    use winapi::um::mmdeviceapi::eCapture;
    use winapi::um::mmdeviceapi::eConsole;
    use winapi::um::mmdeviceapi::eRender;
    use winapi::um::mmdeviceapi::IMMDeviceCollection;

    use super::*;

    #[test]
    fn test_query_interface_valid() {
        let notification_client =
            WinIMMNotificationClient::create_com_ptr(Arc::new(AtomicBool::new(false)), eRender);

        for valid_ref_iid in [
            IUnknown::uuidof(),
            IMMNotificationClient::uuidof(),
            IAgileObject::uuidof(),
        ] {
            let mut ppv_object: *mut c_void = std::ptr::null_mut();
            let res = query_interface(&notification_client, &valid_ref_iid, &mut ppv_object);
            // Check the result before releasing. A failed `QueryInterface` adds no reference, so
            // releasing anyway would free the object while `notification_client` still owns it.
            assert_eq!(res, NOERROR);
            assert_eq!(ppv_object, notification_client.as_raw() as *mut c_void);

            // Release the reference added by `QueryInterface`.
            release(&notification_client);
        }
    }

    #[test]
    fn test_query_interface_invalid() {
        let notification_client =
            WinIMMNotificationClient::create_com_ptr(Arc::new(AtomicBool::new(false)), eRender);
        let invalid_ref_iid = IMMDeviceCollection::uuidof();
        // Start non-null so the test proves `QueryInterface` clears the out pointer on failure.
        let mut ppv_object: *mut c_void = std::ptr::NonNull::<c_void>::dangling().as_ptr();

        let res = query_interface(&notification_client, &invalid_ref_iid, &mut ppv_object);
        assert_eq!(res, E_NOINTERFACE);
        assert!(ppv_object.is_null());
    }

    #[test]
    fn test_release() {
        // ref_count = 1
        let notification_client =
            WinIMMNotificationClient::create_com_ptr(Arc::new(AtomicBool::new(false)), eCapture);
        // ref_count = 2
        let ref_count = add_ref(&notification_client);
        assert_eq!(ref_count, 2);
        // ref_count = 1
        let ref_count = release(&notification_client);
        assert_eq!(ref_count, 1);
        // ref_count = 0 since ComPtr drops
    }

    #[test]
    fn test_on_default_device_changed_filters_by_flow() {
        let device_available = Arc::new(AtomicBool::new(false));
        let notification_client =
            WinIMMNotificationClient::create_com_ptr(Arc::clone(&device_available), eRender);
        let device_id: Vec<u16> = "test_device\0".encode_utf16().collect();

        // A new default device for a different data flow must be ignored.
        let res = default_device_changed(&notification_client, eCapture, device_id.as_ptr());
        assert_eq!(res, NOERROR);
        assert!(!device_available.load(SeqCst));

        // A new default device for the watched data flow must be reported.
        let res = default_device_changed(&notification_client, eRender, device_id.as_ptr());
        assert_eq!(res, NOERROR);
        assert!(device_available.load(SeqCst));
    }

    fn query_interface(
        notification_client: &ComPtr<IMMNotificationClient>,
        riid: REFIID,
        ppv_object: &mut *mut c_void,
    ) -> HRESULT {
        // SAFETY: notification_client has a valid lpVtbl pointer, `riid` points to a live IID
        // and `ppv_object` is a valid out pointer.
        unsafe {
            ((*notification_client.lpVtbl).parent.QueryInterface)(
                notification_client.as_raw() as *mut IUnknown,
                riid,
                ppv_object,
            )
        }
    }

    fn default_device_changed(
        notification_client: &ComPtr<IMMNotificationClient>,
        flow: EDataFlow,
        device_id: LPCWSTR,
    ) -> HRESULT {
        // SAFETY: notification_client has a valid lpVtbl pointer and `device_id` is a
        // null-terminated wide string that outlives the call.
        unsafe {
            ((*notification_client.lpVtbl).OnDefaultDeviceChanged)(
                notification_client.as_raw(),
                flow,
                eConsole,
                device_id,
            )
        }
    }

    fn release(notification_client: &ComPtr<IMMNotificationClient>) -> ULONG {
        // SAFETY: notification_client has a valid lpVtbl pointer
        unsafe {
            ((*notification_client.lpVtbl).parent.Release)(
                notification_client.as_raw() as *mut IUnknown
            )
        }
    }

    fn add_ref(notification_client: &ComPtr<IMMNotificationClient>) -> ULONG {
        // SAFETY: notification_client has a valid lpVtbl pointer
        unsafe {
            ((*notification_client.lpVtbl).parent.AddRef)(
                notification_client.as_raw() as *mut IUnknown
            )
        }
    }
}
343