1# Copyright 2021-2023 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import os
20import functools
21import pytest
22import logging
23
24from bumble import device
25from bumble.hci import CodecID, CodingFormat
26from bumble.profiles.ascs import (
27    AudioStreamControlService,
28    AudioStreamControlServiceProxy,
29    AseStateMachine,
30    ASE_Operation,
31    ASE_Config_Codec,
32    ASE_Config_QOS,
33    ASE_Disable,
34    ASE_Enable,
35    ASE_Receiver_Start_Ready,
36    ASE_Receiver_Stop_Ready,
37    ASE_Release,
38    ASE_Update_Metadata,
39)
40from bumble.profiles.bap import (
41    AudioLocation,
42    SupportedFrameDuration,
43    SupportedSamplingFrequency,
44    SamplingFrequency,
45    FrameDuration,
46    CodecSpecificCapabilities,
47    CodecSpecificConfiguration,
48    ContextType,
49)
50from bumble.profiles.pacs import (
51    PacRecord,
52    PublishedAudioCapabilitiesService,
53    PublishedAudioCapabilitiesServiceProxy,
54)
55from bumble.profiles.le_audio import Metadata
56from tests.test_utils import TwoDevices
57
58
59# -----------------------------------------------------------------------------
60# Logging
61# -----------------------------------------------------------------------------
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
63
64
65# -----------------------------------------------------------------------------
def basic_check(operation: ASE_Operation):
    """Assert that `operation` serializes identically after a parse round trip.

    The operation is converted to bytes, parsed back with
    `ASE_Operation.from_bytes`, and the re-serialized result is compared with
    the first serialization.
    """
    wire_format = bytes(operation)
    assert bytes(ASE_Operation.from_bytes(wire_format)) == wire_format
70
71
72# -----------------------------------------------------------------------------
def test_codec_specific_capabilities() -> None:
    """Round-trip a CodecSpecificCapabilities object through its byte form."""
    capabilities = CodecSpecificCapabilities(
        supported_sampling_frequencies=SupportedSamplingFrequency.FREQ_16000,
        supported_frame_durations=SupportedFrameDuration.DURATION_10000_US_SUPPORTED,
        supported_audio_channel_count=[1],
        min_octets_per_codec_frame=40,
        max_octets_per_codec_frame=40,
        supported_max_codec_frames_per_sdu=1,
    )
    # Parsing the serialized form must yield an equal object.
    round_tripped = CodecSpecificCapabilities.from_bytes(bytes(capabilities))
    assert round_tripped == capabilities
86
87
88# -----------------------------------------------------------------------------
def test_pac_record() -> None:
    """Round-trip an LC3 PacRecord (capabilities + metadata) through bytes."""
    capabilities = CodecSpecificCapabilities(
        supported_sampling_frequencies=SupportedSamplingFrequency.FREQ_16000,
        supported_frame_durations=SupportedFrameDuration.DURATION_10000_US_SUPPORTED,
        supported_audio_channel_count=[1],
        min_octets_per_codec_frame=40,
        max_octets_per_codec_frame=40,
        supported_max_codec_frames_per_sdu=1,
    )
    # A single vendor-specific metadata entry with an empty payload.
    vendor_entry = Metadata.Entry(tag=Metadata.Tag.VENDOR_SPECIFIC, data=b'')

    record = PacRecord(
        coding_format=CodingFormat(CodecID.LC3),
        codec_specific_capabilities=capabilities,
        metadata=Metadata([vendor_entry]),
    )
    round_tripped = PacRecord.from_bytes(bytes(record))
    assert round_tripped == record
108
109
110# -----------------------------------------------------------------------------
def test_vendor_specific_pac_record() -> None:
    """Check that a raw vendor-specific PAC record re-serializes unchanged."""
    # Vendor-Specific codec, Google, ID=0xFFFF. No capabilities and metadata.
    raw_record = bytes.fromhex('ffe000ffff0000')
    parsed_record = PacRecord.from_bytes(raw_record)
    assert bytes(parsed_record) == raw_record
115
116
117# -----------------------------------------------------------------------------
def test_ASE_Config_Codec() -> None:
    """Round-trip an ASE_Config_Codec operation addressing two ASEs."""
    lc3 = [CodingFormat(CodecID.LC3), CodingFormat(CodecID.LC3)]
    basic_check(
        ASE_Config_Codec(
            ase_id=[1, 2],
            target_latency=[3, 4],
            target_phy=[5, 6],
            codec_id=lc3,
            codec_specific_configuration=[b'foo', b'bar'],
        )
    )
127
128
129# -----------------------------------------------------------------------------
def test_ASE_Config_QOS() -> None:
    """Round-trip an ASE_Config_QOS operation addressing two ASEs."""
    # One list entry per ASE for every field.
    qos_fields = dict(
        ase_id=[1, 2],
        cig_id=[1, 2],
        cis_id=[3, 4],
        sdu_interval=[5, 6],
        framing=[0, 1],
        phy=[2, 3],
        max_sdu=[4, 5],
        retransmission_number=[6, 7],
        max_transport_latency=[8, 9],
        presentation_delay=[10, 11],
    )
    basic_check(ASE_Config_QOS(**qos_fields))
144
145
146# -----------------------------------------------------------------------------
def test_ASE_Enable() -> None:
    """Round-trip an ASE_Enable operation with empty metadata for two ASEs."""
    basic_check(ASE_Enable(ase_id=[1, 2], metadata=[b'', b'']))
153
154
155# -----------------------------------------------------------------------------
def test_ASE_Update_Metadata() -> None:
    """Round-trip an ASE_Update_Metadata operation with empty metadata."""
    basic_check(ASE_Update_Metadata(ase_id=[1, 2], metadata=[b'', b'']))
162
163
164# -----------------------------------------------------------------------------
def test_ASE_Disable() -> None:
    """Round-trip an ASE_Disable operation addressing two ASEs."""
    basic_check(ASE_Disable(ase_id=[1, 2]))
168
169
170# -----------------------------------------------------------------------------
def test_ASE_Release() -> None:
    """Round-trip an ASE_Release operation addressing two ASEs."""
    basic_check(ASE_Release(ase_id=[1, 2]))
174
175
176# -----------------------------------------------------------------------------
def test_ASE_Receiver_Start_Ready() -> None:
    """Round-trip an ASE_Receiver_Start_Ready operation addressing two ASEs."""
    basic_check(ASE_Receiver_Start_Ready(ase_id=[1, 2]))
180
181
182# -----------------------------------------------------------------------------
def test_ASE_Receiver_Stop_Ready() -> None:
    """Round-trip an ASE_Receiver_Stop_Ready operation addressing two ASEs."""
    basic_check(ASE_Receiver_Stop_Ready(ase_id=[1, 2]))
186
187
188# -----------------------------------------------------------------------------
def test_codec_specific_configuration() -> None:
    """Round-trip a CodecSpecificConfiguration object through its byte form."""
    configuration = CodecSpecificConfiguration(
        sampling_frequency=SamplingFrequency.FREQ_16000,
        frame_duration=FrameDuration.DURATION_10000_US,
        audio_channel_allocation=AudioLocation.FRONT_LEFT,
        octets_per_codec_frame=60,
        codec_frames_per_sdu=1,
    )
    # Parsing the serialized form must yield an equal object.
    round_tripped = CodecSpecificConfiguration.from_bytes(bytes(configuration))
    assert round_tripped == configuration
201
202
203# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_pacs():
    """Publish a PACS service on device 0 and discover it from the peer side.

    The service advertises two LC3 sink PAC records (the 16_2 and 24_2 codec
    capability settings) and front-left/front-right sink audio locations.
    After the two devices are connected, the service is discovered through a
    `device.Peer` and a `PublishedAudioCapabilitiesServiceProxy` is created.
    """

    def lc3_sink_pac(sampling_frequency, octets_per_codec_frame):
        # Build a mono, 10 ms, single-frame-per-SDU LC3 PAC record; only the
        # sampling frequency and the (fixed) frame size differ between records.
        return PacRecord(
            coding_format=CodingFormat(CodecID.LC3),
            codec_specific_capabilities=CodecSpecificCapabilities(
                supported_sampling_frequencies=sampling_frequency,
                supported_frame_durations=(
                    SupportedFrameDuration.DURATION_10000_US_SUPPORTED
                ),
                supported_audio_channel_count=[1],
                min_octets_per_codec_frame=octets_per_codec_frame,
                max_octets_per_codec_frame=octets_per_codec_frame,
                supported_max_codec_frames_per_sdu=1,
            ),
        )

    devices = TwoDevices()
    devices[0].add_service(
        PublishedAudioCapabilitiesService(
            supported_sink_context=ContextType.MEDIA,
            available_sink_context=ContextType.MEDIA,
            supported_source_context=0,
            available_source_context=0,
            sink_pac=[
                # Codec Capability Setting 16_2
                lc3_sink_pac(SupportedSamplingFrequency.FREQ_16000, 40),
                # Codec Capability Setting 24_2
                lc3_sink_pac(SupportedSamplingFrequency.FREQ_24000, 60),
            ],
            sink_audio_locations=AudioLocation.FRONT_LEFT | AudioLocation.FRONT_RIGHT,
        )
    )

    await devices.setup_connection()
    peer = device.Peer(devices.connections[1])
    pacs_client = await peer.discover_service_and_create_proxy(
        PublishedAudioCapabilitiesServiceProxy
    )

    # Without this check the test would pass even if no proxy was produced.
    # NOTE(review): assumes a failed discovery yields None rather than raising
    # -- confirm against device.Peer.
    assert pacs_client is not None
256
257
258# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_ascs():
    """Walk two sink ASEs through the ASCS state machine end to end.

    Device 0 hosts an AudioStreamControlService with sink ASE IDs 1 and 2.
    The test writes ASE control point operations through the proxy discovered
    over connection 1 and checks the first two bytes (ASE ID, state) of each
    notification received for each ASE.

    Raises:
        asyncio.TimeoutError: if an expected notification does not arrive
            within the per-notification timeout.
    """
    # Upper bound (seconds) on the wait for each notification, so that a
    # missing state transition fails the test instead of blocking it forever.
    notification_timeout = 5.0

    devices = TwoDevices()
    devices[0].add_service(
        AudioStreamControlService(device=devices[0], sink_ase_id=[1, 2])
    )

    await devices.setup_connection()
    peer = device.Peer(devices.connections[1])
    ascs_client = await peer.discover_service_and_create_proxy(
        AudioStreamControlServiceProxy
    )

    # One queue of raw notification payloads per ASE ID.
    notifications = {1: asyncio.Queue(), 2: asyncio.Queue()}

    def on_notification(data: bytes, ase_id: int):
        notifications[ase_id].put_nowait(data)

    async def expect_state(state) -> None:
        # Consume the next notification of ASE 1, then of ASE 2, and check
        # that each one starts with (ASE ID, expected state).
        for ase_id in (1, 2):
            data = await asyncio.wait_for(
                notifications[ase_id].get(), timeout=notification_timeout
            )
            assert data[:2] == bytes([ase_id, state])

    # Should be idle
    for index, ase_id in enumerate((1, 2)):
        assert await ascs_client.sink_ase[index].read_value() == bytes(
            [ase_id, AseStateMachine.State.IDLE]
        )

    # Subscribe
    for index, ase_id in enumerate((1, 2)):
        await ascs_client.sink_ase[index].subscribe(
            functools.partial(on_notification, ase_id=ase_id)
        )

    # Config Codec
    config = CodecSpecificConfiguration(
        sampling_frequency=SamplingFrequency.FREQ_48000,
        frame_duration=FrameDuration.DURATION_10000_US,
        audio_channel_allocation=AudioLocation.FRONT_LEFT,
        octets_per_codec_frame=120,
        codec_frames_per_sdu=1,
    )
    await ascs_client.ase_control_point.write_value(
        ASE_Config_Codec(
            ase_id=[1, 2],
            target_latency=[3, 4],
            target_phy=[5, 6],
            codec_id=[CodingFormat(CodecID.LC3), CodingFormat(CodecID.LC3)],
            codec_specific_configuration=[config, config],
        )
    )
    await expect_state(AseStateMachine.State.CODEC_CONFIGURED)

    # Config QOS
    await ascs_client.ase_control_point.write_value(
        ASE_Config_QOS(
            ase_id=[1, 2],
            cig_id=[1, 2],
            cis_id=[3, 4],
            sdu_interval=[5, 6],
            framing=[0, 1],
            phy=[2, 3],
            max_sdu=[4, 5],
            retransmission_number=[6, 7],
            max_transport_latency=[8, 9],
            presentation_delay=[10, 11],
        )
    )
    await expect_state(AseStateMachine.State.QOS_CONFIGURED)

    # Enable
    await ascs_client.ase_control_point.write_value(
        ASE_Enable(
            ase_id=[1, 2],
            metadata=[b'foo', b'bar'],
        )
    )
    await expect_state(AseStateMachine.State.ENABLING)

    # CIS establishment: the (cig_id, cis_id) pairs are the ones written in
    # the Config QOS step above -- (1, 3) for ASE 1 and (2, 4) for ASE 2.
    for handle, cis_id, cig_id in ((5, 3, 1), (6, 4, 2)):
        devices[0].emit(
            'cis_establishment',
            device.CisLink(
                device=devices[0],
                acl_connection=devices.connections[0],
                handle=handle,
                cis_id=cis_id,
                cig_id=cig_id,
            ),
        )
    await expect_state(AseStateMachine.State.STREAMING)

    # Release (the operation is built from ASE IDs only)
    await ascs_client.ase_control_point.write_value(ASE_Release(ase_id=[1, 2]))
    await expect_state(AseStateMachine.State.RELEASING)
    await expect_state(AseStateMachine.State.IDLE)

    # NOTE(review): presumably yields to the event loop so pending callbacks
    # can run before the test returns -- confirm.
    await asyncio.sleep(0.001)
398
399
400# -----------------------------------------------------------------------------
async def run():
    """Run the asynchronous tests of this module one after the other.

    Used by the script entry point so the connection-based tests can be run
    without pytest.
    """
    await test_pacs()
    await test_ascs()
403
404
405# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # Log level comes from BUMBLE_LOGLEVEL (default 'INFO'), upper-cased.
    logging.basicConfig(level=os.getenv('BUMBLE_LOGLEVEL', 'INFO').upper())
    asyncio.run(run())
409