1 // Copyright 2023 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::sync::atomic::AtomicBool;
6 use std::sync::atomic::AtomicU32;
7 use std::sync::atomic::Ordering::SeqCst;
8 use std::sync::Arc;
9
10 use base::info;
11 use libc::c_void;
12 use winapi::shared::guiddef::IsEqualGUID;
13 use winapi::shared::guiddef::REFIID;
14 use winapi::shared::minwindef::DWORD;
15 use winapi::shared::minwindef::ULONG;
16 use winapi::shared::winerror::E_INVALIDARG;
17 use winapi::shared::winerror::E_NOINTERFACE;
18 use winapi::shared::winerror::NOERROR;
19 use winapi::shared::wtypes::PROPERTYKEY;
20 use winapi::um::mmdeviceapi::EDataFlow;
21 use winapi::um::mmdeviceapi::ERole;
22 use winapi::um::mmdeviceapi::IMMNotificationClient;
23 use winapi::um::mmdeviceapi::IMMNotificationClientVtbl;
24 use winapi::um::objidlbase::IAgileObject;
25 use winapi::um::unknwnbase::IUnknown;
26 use winapi::um::unknwnbase::IUnknownVtbl;
27 use winapi::um::winnt::HRESULT;
28 use winapi::um::winnt::LPCWSTR;
29 use winapi::Interface;
30 use wio::com::ComPtr;
31
/// This device notification client will be used to notify win_audio when a new audio device is
/// available. This notification client will only be registered when there are
/// no audio devices detected.
///
/// The struct is `#[repr(C)]` with `lp_vtbl` as its first field so that a pointer to it can be
/// handed to Windows as an `IMMNotificationClient` pointer (see `create_com_ptr`).
#[repr(C)]
pub(crate) struct WinIMMNotificationClient {
    // Must stay the first field: COM callers locate the vtable through the object pointer.
    pub lp_vtbl: &'static IMMNotificationClientVtbl,
    // COM reference count, maintained by the vtable's `AddRef` and `Release`. The object is
    // freed when it reaches 0.
    ref_count: AtomicU32,
    // Shared with `WinAudioRenderer`. This will be used in `next_playback_buffer` only when
    // `NoopStream` is being used. When this is set to `true`, `WinAudioRenderer` will attempt
    // to create a new `DeviceRenderer`.
    device_available: Arc<AtomicBool>,
    // The data flow (e.g. render or capture) whose default-device changes set
    // `device_available`.
    data_flow: EDataFlow,
}
45
46 impl WinIMMNotificationClient {
47 /// The ComPtr is a `WinIMMNotificationClient` casted as an `IMMNotificationClient`.
create_com_ptr( device_available: Arc<AtomicBool>, data_flow: EDataFlow, ) -> ComPtr<IMMNotificationClient>48 pub(crate) fn create_com_ptr(
49 device_available: Arc<AtomicBool>,
50 data_flow: EDataFlow,
51 ) -> ComPtr<IMMNotificationClient> {
52 let win_imm_notification_client = Box::new(WinIMMNotificationClient {
53 lp_vtbl: IMM_NOTIFICATION_CLIENT_VTBL,
54 ref_count: AtomicU32::new(1),
55 device_available,
56 data_flow,
57 });
58
59 // This is safe if the value passed into `from_raw` is structured in a way where it can
60 // match `IMMNotificationClient`. Since `win_imm_notification_client.cast_to_com_ptr()`
61 // does, this is safe.
62 //
63 // SAFETY: We are passing in a valid COM object that implements `IUnknown` into
64 // `from_raw`.
65 unsafe {
66 ComPtr::from_raw(
67 Box::into_raw(win_imm_notification_client) as *mut IMMNotificationClient
68 )
69 }
70 }
71
increment_counter(&self) -> ULONG72 fn increment_counter(&self) -> ULONG {
73 self.ref_count.fetch_add(1, SeqCst) + 1
74 }
75
decrement_counter(&mut self) -> ULONG76 fn decrement_counter(&mut self) -> ULONG {
77 let old_val = self.ref_count.fetch_sub(1, SeqCst);
78 assert_ne!(
79 old_val, 0,
80 "Attempted to decrement WinIMMNotificationClient ref count when it \
81 is already 0."
82 );
83 old_val - 1
84 }
85 }
86
87 impl Drop for WinIMMNotificationClient {
drop(&mut self)88 fn drop(&mut self) {
89 info!("IMMNotificationClient is dropped.");
90 }
91 }
92
93 // TODO(b/274146821): Factor out common IUnknown code between here and `completion_handler.rs`.
94 const IMM_NOTIFICATION_CLIENT_VTBL: &IMMNotificationClientVtbl = {
95 &IMMNotificationClientVtbl {
96 parent: IUnknownVtbl {
97 QueryInterface: {
98 /// Safe because if `this` is not implemented (fails the RIID check) this function
99 /// will just return. If it valid, it should be able to safely increment the ref
100 /// counter and set the pointer `ppv_object`.
query_interface( this: *mut IUnknown, riid: REFIID, ppv_object: *mut *mut c_void, ) -> HRESULT101 unsafe extern "system" fn query_interface(
102 this: *mut IUnknown,
103 riid: REFIID,
104 ppv_object: *mut *mut c_void,
105 ) -> HRESULT {
106 info!("querying ref in IMMNotificationClient.");
107 if ppv_object.is_null() {
108 return E_INVALIDARG;
109 }
110
111 *ppv_object = std::ptr::null_mut();
112
113 // Check for valid RIID's
114 if IsEqualGUID(&*riid, &IUnknown::uuidof())
115 || IsEqualGUID(&*riid, &IMMNotificationClient::uuidof())
116 || IsEqualGUID(&*riid, &IAgileObject::uuidof())
117 {
118 *ppv_object = this as *mut c_void;
119 (*this).AddRef();
120 return NOERROR;
121 }
122 E_NOINTERFACE
123 }
124 query_interface
125 },
126 AddRef: {
127 /// Unsafe if `this` cannot be casted to `WinIMMNotificationClient`.
128 ///
129 /// This is safe because `this` is originally a `WinIMMNotificationClient`.
add_ref(this: *mut IUnknown) -> ULONG130 unsafe extern "system" fn add_ref(this: *mut IUnknown) -> ULONG {
131 info!("Adding ref in IMMNotificationClient.");
132 let win_imm_notification_client = this as *mut WinIMMNotificationClient;
133 (*win_imm_notification_client).increment_counter()
134 }
135 add_ref
136 },
137 Release: {
138 /// Unsafe if `this` cannot because casted to `WinIMMNotificationClient`. Also
139 /// would be unsafe if `release` is called more than `add_ref`.
140 ///
141 /// This is safe because `this` is
142 /// originally a `WinIMMNotificationClient` and isn't called
143 /// more than `add_ref`.
release(this: *mut IUnknown) -> ULONG144 unsafe extern "system" fn release(this: *mut IUnknown) -> ULONG {
145 info!("Releasing ref in IMMNotificationClient.");
146 // Decrementing will free the `this` pointer if it's ref_count becomes 0.
147 let win_imm_notification_client = this as *mut WinIMMNotificationClient;
148 let ref_count = (*win_imm_notification_client).decrement_counter();
149 if ref_count == 0 {
150 // Delete the pointer
151 drop(Box::from_raw(this as *mut WinIMMNotificationClient));
152 }
153 ref_count
154 }
155 release
156 },
157 },
158 OnDeviceStateChanged: on_device_state_change,
159 OnDeviceAdded: on_device_added,
160 OnDeviceRemoved: on_device_removed,
161 OnDefaultDeviceChanged: on_default_device_changed,
162 OnPropertyValueChanged: on_property_value_changed,
163 }
164 };
165
on_device_state_change( _this: *mut IMMNotificationClient, _pwstr_device_id: LPCWSTR, _dw_new_state: DWORD, ) -> HRESULT166 unsafe extern "system" fn on_device_state_change(
167 _this: *mut IMMNotificationClient,
168 _pwstr_device_id: LPCWSTR,
169 _dw_new_state: DWORD,
170 ) -> HRESULT {
171 info!("IMMNotificationClient: on_device_state_change called");
172 0
173 }
174
175 /// Indicates that an audio enpoint device has been added. In practice, I have not seen this get
176 /// triggered, even if I add an audio device.
177 ///
178 /// # Safety
179 /// This is safe because this callback does nothing except for logging.
on_device_added( _this: *mut IMMNotificationClient, _pwstr_device_id: LPCWSTR, ) -> HRESULT180 unsafe extern "system" fn on_device_added(
181 _this: *mut IMMNotificationClient,
182 _pwstr_device_id: LPCWSTR,
183 ) -> HRESULT {
184 info!("IMMNotificationClient: on_device_added called");
185 0
186 }
187
188 /// Indicates that an audio enpoint device has been removed. In practice, I have not seen this get
189 /// triggered, even if I unplug an audio device.
190 ///
191 /// # Safety
192 /// This is safe because this callback does nothing except for logging.
on_device_removed( _this: *mut IMMNotificationClient, _pwstr_device_id: LPCWSTR, ) -> HRESULT193 unsafe extern "system" fn on_device_removed(
194 _this: *mut IMMNotificationClient,
195 _pwstr_device_id: LPCWSTR,
196 ) -> HRESULT {
197 info!("IMMNotificationClient: on_device_removed called");
198 0
199 }
200
201 /// Indicates that the default device has changed. In practice, this callback seemed reliable to
202 /// tell us when a new audio device has been added when no devices were previously present.
203 ///
204 /// # Safety
205 /// Safe because we know `IMMNotificationClient` was originally a `WinIMMNotificationClient`,
206 /// so we can cast safely.
on_default_device_changed( this: *mut IMMNotificationClient, flow: EDataFlow, _role: ERole, _pwstr_default_device_id: LPCWSTR, ) -> HRESULT207 unsafe extern "system" fn on_default_device_changed(
208 this: *mut IMMNotificationClient,
209 flow: EDataFlow,
210 _role: ERole,
211 _pwstr_default_device_id: LPCWSTR,
212 ) -> HRESULT {
213 info!("IMMNotificationClient: on_default_device_changed called");
214 let win = &*(this as *mut WinIMMNotificationClient);
215 if flow == win.data_flow {
216 base::info!("New device found");
217 win.device_available.store(true, SeqCst);
218 }
219 0
220 }
221
222 /// Indicates that a property in an audio endpoint device has changed. In practice, this callback
223 /// gets spammed a lot and the information provided isn't useful.
224 ///
225 /// # Safety
226 /// This is safe because this callback does nothing.
on_property_value_changed( _this: *mut IMMNotificationClient, _pwstr_device_id: LPCWSTR, _key: PROPERTYKEY, ) -> HRESULT227 unsafe extern "system" fn on_property_value_changed(
228 _this: *mut IMMNotificationClient,
229 _pwstr_device_id: LPCWSTR,
230 _key: PROPERTYKEY,
231 ) -> HRESULT {
232 0
233 }
234
/// The following tests the correctness of the COM object implementation. It won't test for
/// notifications of new devices.
#[cfg(test)]
mod test {
    use winapi::shared::guiddef::GUID;
    use winapi::um::mmdeviceapi::eCapture;
    use winapi::um::mmdeviceapi::eRender;
    use winapi::um::mmdeviceapi::IMMDeviceCollection;

    use super::*;

    #[test]
    fn test_query_interface_valid() {
        let notification_client =
            WinIMMNotificationClient::create_com_ptr(Arc::new(AtomicBool::new(false)), eRender);

        for valid_ref_iid in [
            IUnknown::uuidof(),
            IMMNotificationClient::uuidof(),
            IAgileObject::uuidof(),
        ] {
            let (res, ppv_object) = query_interface(&notification_client, &valid_ref_iid);
            assert_eq!(res, NOERROR);
            // A successful query hands back the same object.
            assert_eq!(ppv_object, notification_client.as_raw() as *mut c_void);
            // Release the reference added by `QueryInterface`. Do this only after asserting
            // success; releasing after a failed query would drop the `ComPtr`'s own reference.
            assert_eq!(release(&notification_client), 1);
        }
    }

    #[test]
    fn test_query_interface_invalid() {
        let notification_client =
            WinIMMNotificationClient::create_com_ptr(Arc::new(AtomicBool::new(false)), eRender);
        let invalid_ref_iid = IMMDeviceCollection::uuidof();

        let (res, ppv_object) = query_interface(&notification_client, &invalid_ref_iid);
        assert_eq!(res, E_NOINTERFACE);
        // COM requires the out-pointer to be cleared on failure.
        assert!(ppv_object.is_null());
    }

    #[test]
    fn test_query_interface_null_out_pointer() {
        let notification_client =
            WinIMMNotificationClient::create_com_ptr(Arc::new(AtomicBool::new(false)), eRender);
        let valid_ref_iid = IUnknown::uuidof();

        // SAFETY: notification_client has a valid lpVtbl pointer, and a null `ppv_object` is
        // rejected before anything is written through it.
        let res = unsafe {
            ((*notification_client.lpVtbl).parent.QueryInterface)(
                notification_client.as_raw() as *mut IUnknown,
                &valid_ref_iid,
                std::ptr::null_mut(),
            )
        };
        assert_eq!(res, E_INVALIDARG);
    }

    #[test]
    fn test_release() {
        // ref_count = 1
        let notification_client =
            WinIMMNotificationClient::create_com_ptr(Arc::new(AtomicBool::new(false)), eCapture);
        // ref_count = 2
        let ref_count = add_ref(&notification_client);
        assert_eq!(ref_count, 2);
        // ref_count = 1
        let ref_count = release(&notification_client);
        assert_eq!(ref_count, 1);
        // ref_count = 0 since ComPtr drops
    }

    /// Calls `QueryInterface` on `notification_client` for `riid` and returns the `HRESULT`
    /// together with the pointer written to the out-parameter.
    ///
    /// The out-parameter starts as a non-null dangling pointer, so callers can check that a
    /// failed query really clears it.
    fn query_interface(
        notification_client: &ComPtr<IMMNotificationClient>,
        riid: &GUID,
    ) -> (HRESULT, *mut c_void) {
        let mut ppv_object: *mut c_void = std::ptr::NonNull::<c_void>::dangling().as_ptr();
        // SAFETY: notification_client has a valid lpVtbl pointer, and `riid` and `ppv_object`
        // are valid for the duration of the call.
        let res = unsafe {
            ((*notification_client.lpVtbl).parent.QueryInterface)(
                notification_client.as_raw() as *mut IUnknown,
                riid,
                &mut ppv_object,
            )
        };
        (res, ppv_object)
    }

    fn release(notification_client: &ComPtr<IMMNotificationClient>) -> ULONG {
        // SAFETY: notification_client has a valid lpVtbl pointer
        unsafe {
            ((*notification_client.lpVtbl).parent.Release)(
                notification_client.as_raw() as *mut IUnknown
            )
        }
    }

    fn add_ref(notification_client: &ComPtr<IMMNotificationClient>) -> ULONG {
        // SAFETY: notification_client has a valid lpVtbl pointer
        unsafe {
            ((*notification_client.lpVtbl).parent.AddRef)(
                notification_client.as_raw() as *mut IUnknown
            )
        }
    }
}
343