1# Copyright 2024 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19import asyncio
20import contextlib
21import dataclasses
22import logging
23import os
24from typing import cast, Any, AsyncGenerator, Coroutine, Dict, Optional, Tuple
25
26import click
27import pyee
28
29from bumble.colors import color
30import bumble.company_ids
31import bumble.core
32import bumble.device
33import bumble.gatt
34import bumble.hci
35import bumble.profiles.bap
36import bumble.profiles.bass
37import bumble.profiles.pbp
38import bumble.transport
39import bumble.utils
40
41
42# -----------------------------------------------------------------------------
43# Logging
44# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Name and fixed address given to the local device by create_device().
AURACAST_DEFAULT_DEVICE_NAME = 'Bumble Auracast'
AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address('F0:F1:F2:F3:F4:F5')
# Periodic advertising sync timeout, in seconds: default for `scan --sync-timeout`
# and the timeout used by find_broadcast_by_name().
AURACAST_DEFAULT_SYNC_TIMEOUT = 5.0
# ATT MTU that run_assist() requests from the BASS server after connecting.
AURACAST_DEFAULT_ATT_MTU = 256
55
56
57# -----------------------------------------------------------------------------
58# Scan For Broadcasts
59# -----------------------------------------------------------------------------
class BroadcastScanner(pyee.EventEmitter):
    """Discover broadcasts and track each one through a periodic advertising sync.

    Any advertisement carrying a Broadcast Name is treated as a broadcast. The
    first time an advertiser address is seen, a periodic advertising sync is
    created for it and a :class:`Broadcast` is stored in ``broadcasts``. Later
    advertisements from the same address only refresh that entry.

    Events emitted:
      * ``new_broadcast(broadcast)``: a broadcast was added to ``broadcasts``.
      * ``broadcast_loss(broadcast)``: a broadcast's periodic sync was lost. The
        broadcast was removed and termination of its sync was scheduled.
    """

    @dataclasses.dataclass
    class Broadcast(pyee.EventEmitter):
        """Information collected about a single broadcast source.

        Data from regular advertisements is applied by :meth:`update`. The
        Basic Audio Announcement and BIGInfo arrive through ``sync``.

        Events emitted:
          * ``update``: regular advertising data was applied.
          * ``change``: a periodic advertisement or BIGInfo report was received.
          * ``sync_establishment`` / ``sync_loss``: relayed from ``sync``.
        """

        name: str
        sync: bumble.device.PeriodicAdvertisingSync
        rssi: int = 0
        public_broadcast_announcement: Optional[
            bumble.profiles.pbp.PublicBroadcastAnnouncement
        ] = None
        broadcast_audio_announcement: Optional[
            bumble.profiles.bap.BroadcastAudioAnnouncement
        ] = None
        # Set from periodic advertisements; cleared when the sync is lost.
        basic_audio_announcement: Optional[
            bumble.profiles.bap.BasicAudioAnnouncement
        ] = None
        appearance: Optional[bumble.core.Appearance] = None
        # Set from BIGInfo reports; cleared when the sync is lost.
        biginfo: Optional[bumble.device.BIGInfoAdvertisement] = None
        # (company name, or '0xNNNN' when the ID is not known; raw payload)
        manufacturer_data: Optional[Tuple[str, bytes]] = None

        def __post_init__(self) -> None:
            # The dataclass-generated __init__ does not call the base class
            # initializer, so set up the event emitter here.
            super().__init__()
            self.sync.on('establishment', self.on_sync_establishment)
            self.sync.on('loss', self.on_sync_loss)
            self.sync.on('periodic_advertisement', self.on_periodic_advertisement)
            self.sync.on('biginfo_advertisement', self.on_biginfo_advertisement)

        def update(self, advertisement: bumble.device.Advertisement) -> None:
            """Refresh RSSI, announcements, appearance and manufacturer data.

            Emits ``update`` when done.
            """
            self.rssi = advertisement.rssi
            for service_data in advertisement.data.get_all(
                bumble.core.AdvertisingData.SERVICE_DATA
            ):
                assert isinstance(service_data, tuple)
                service_uuid, data = service_data
                assert isinstance(data, bytes)

                if (
                    service_uuid
                    == bumble.gatt.GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE
                ):
                    self.public_broadcast_announcement = (
                        bumble.profiles.pbp.PublicBroadcastAnnouncement.from_bytes(data)
                    )
                    continue

                if (
                    service_uuid
                    == bumble.gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
                ):
                    self.broadcast_audio_announcement = (
                        bumble.profiles.bap.BroadcastAudioAnnouncement.from_bytes(data)
                    )
                    continue

            self.appearance = advertisement.data.get(  # type: ignore[assignment]
                bumble.core.AdvertisingData.APPEARANCE
            )

            if manufacturer_data := advertisement.data.get(
                bumble.core.AdvertisingData.MANUFACTURER_SPECIFIC_DATA
            ):
                assert isinstance(manufacturer_data, tuple)
                company_id = cast(int, manufacturer_data[0])
                data = cast(bytes, manufacturer_data[1])
                # Show the company name when known, otherwise the hex ID.
                self.manufacturer_data = (
                    bumble.company_ids.COMPANY_IDENTIFIERS.get(
                        company_id, f'0x{company_id:04X}'
                    ),
                    data,
                )

            self.emit('update')

        def print(self) -> None:
            """Print a human-readable summary of this broadcast to stdout."""
            print(
                color('Broadcast:', 'yellow'),
                self.sync.advertiser_address,
                color(self.sync.state.name, 'green'),
            )
            print(f'  {color("Name", "cyan")}:         {self.name}')
            if self.appearance:
                print(f'  {color("Appearance", "cyan")}:   {str(self.appearance)}')
            print(f'  {color("RSSI", "cyan")}:         {self.rssi}')
            print(f'  {color("SID", "cyan")}:          {self.sync.sid}')

            if self.manufacturer_data:
                print(
                    f'  {color("Manufacturer Data", "cyan")}: '
                    f'{self.manufacturer_data[0]} -> {self.manufacturer_data[1].hex()}'
                )

            if self.broadcast_audio_announcement:
                print(
                    f'  {color("Broadcast ID", "cyan")}: '
                    f'{self.broadcast_audio_announcement.broadcast_id}'
                )

            if self.public_broadcast_announcement:
                print(
                    f'  {color("Features", "cyan")}:     '
                    f'{self.public_broadcast_announcement.features}'
                )
                print(
                    f'  {color("Metadata", "cyan")}:     '
                    f'{self.public_broadcast_announcement.metadata}'
                )

            if self.basic_audio_announcement:
                print(color('  Audio:', 'cyan'))
                print(
                    color('    Presentation Delay:', 'magenta'),
                    self.basic_audio_announcement.presentation_delay,
                )
                for subgroup in self.basic_audio_announcement.subgroups:
                    print(color('    Subgroup:', 'magenta'))
                    print(color('      Codec ID:', 'yellow'))
                    print(
                        color('        Coding Format:           ', 'green'),
                        subgroup.codec_id.coding_format.name,
                    )
                    print(
                        color('        Company ID:              ', 'green'),
                        subgroup.codec_id.company_id,
                    )
                    print(
                        color('        Vendor Specific Codec ID:', 'green'),
                        subgroup.codec_id.vendor_specific_codec_id,
                    )
                    print(
                        color('      Codec Config:', 'yellow'),
                        subgroup.codec_specific_configuration,
                    )
                    print(color('      Metadata:    ', 'yellow'), subgroup.metadata)

                    for bis in subgroup.bis:
                        print(color(f'      BIS [{bis.index}]:', 'yellow'))
                        print(
                            color('       Codec Config:', 'green'),
                            bis.codec_specific_configuration,
                        )

            if self.biginfo:
                print(color('  BIG:', 'cyan'))
                print(
                    color('    Number of BIS:', 'magenta'),
                    self.biginfo.num_bis,
                )
                print(
                    color('    PHY:          ', 'magenta'),
                    self.biginfo.phy.name,
                )
                print(
                    color('    Framed:       ', 'magenta'),
                    self.biginfo.framed,
                )
                print(
                    color('    Encrypted:    ', 'magenta'),
                    self.biginfo.encrypted,
                )

        def on_sync_establishment(self) -> None:
            self.emit('sync_establishment')

        def on_sync_loss(self) -> None:
            # Data that only arrives over the sync is stale once it is lost.
            self.basic_audio_announcement = None
            self.biginfo = None
            self.emit('sync_loss')

        def on_periodic_advertisement(
            self, advertisement: bumble.device.PeriodicAdvertisement
        ) -> None:
            """Pick up the Basic Audio Announcement, then emit ``change``."""
            if advertisement.data is None:
                return

            for service_data in advertisement.data.get_all(
                bumble.core.AdvertisingData.SERVICE_DATA
            ):
                assert isinstance(service_data, tuple)
                service_uuid, data = service_data
                assert isinstance(data, bytes)

                if service_uuid == bumble.gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE:
                    self.basic_audio_announcement = (
                        bumble.profiles.bap.BasicAudioAnnouncement.from_bytes(data)
                    )
                    break

            self.emit('change')

        def on_biginfo_advertisement(
            self, advertisement: bumble.device.BIGInfoAdvertisement
        ) -> None:
            self.biginfo = advertisement
            self.emit('change')

    def __init__(
        self,
        device: bumble.device.Device,
        filter_duplicates: bool,
        sync_timeout: float,
    ):
        """
        Args:
            device: device used for scanning and periodic advertising sync.
            filter_duplicates: passed to each periodic advertising sync.
            sync_timeout: periodic advertising sync timeout (seconds).
        """
        super().__init__()
        self.device = device
        self.filter_duplicates = filter_duplicates
        self.sync_timeout = sync_timeout
        self.broadcasts: Dict[bumble.hci.Address, BroadcastScanner.Broadcast] = {}
        # Advertisers for which a periodic advertising sync is being created.
        # This prevents repeated advertisements from starting duplicate syncs.
        self._pending_addresses: set[bumble.hci.Address] = set()
        device.on('advertisement', self.on_advertisement)

    async def start(self) -> None:
        """Start passive scanning.

        Duplicate filtering is always disabled at the scan level, so repeated
        advertisements keep reaching Broadcast.update() (e.g. refreshing RSSI).
        ``self.filter_duplicates`` only applies to the periodic syncs.
        """
        await self.device.start_scanning(
            active=False,
            filter_duplicates=False,
        )

    async def stop(self) -> None:
        await self.device.stop_scanning()

    def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None:
        """Route an advertisement to its broadcast, or start tracking a new one."""
        if (
            broadcast_name := advertisement.data.get(
                bumble.core.AdvertisingData.BROADCAST_NAME
            )
        ) is None:
            return
        assert isinstance(broadcast_name, str)

        if broadcast := self.broadcasts.get(advertisement.address):
            broadcast.update(advertisement)
            return

        # A sync for this advertiser is already being created. Further
        # advertisements are ignored until that attempt finishes.
        if advertisement.address in self._pending_addresses:
            return

        self._pending_addresses.add(advertisement.address)
        bumble.utils.AsyncRunner.spawn(
            self.on_new_broadcast(broadcast_name, advertisement)
        )

    async def on_new_broadcast(
        self, name: str, advertisement: bumble.device.Advertisement
    ) -> None:
        """Create a periodic advertising sync for a newly seen broadcast.

        On success the broadcast is stored in ``broadcasts`` and
        ``new_broadcast`` is emitted. The address always leaves the pending set,
        so a failed attempt is retried on a later advertisement.
        """
        try:
            periodic_advertising_sync = (
                await self.device.create_periodic_advertising_sync(
                    advertiser_address=advertisement.address,
                    sid=advertisement.sid,
                    sync_timeout=self.sync_timeout,
                    filter_duplicates=self.filter_duplicates,
                )
            )
        finally:
            self._pending_addresses.discard(advertisement.address)

        broadcast = self.Broadcast(
            name,
            periodic_advertising_sync,
        )
        broadcast.update(advertisement)
        self.broadcasts[advertisement.address] = broadcast
        periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast))
        self.emit('new_broadcast', broadcast)

    def on_broadcast_loss(self, broadcast: Broadcast) -> None:
        """Forget a broadcast whose sync was lost and terminate that sync."""
        if self.broadcasts.pop(broadcast.sync.advertiser_address, None) is None:
            # Already forgotten; this is defensive against a repeated loss report.
            return
        bumble.utils.AsyncRunner.spawn(broadcast.sync.terminate())
        self.emit('broadcast_loss', broadcast)
315
316
class PrintingBroadcastScanner:
    """Console front-end that redraws every known broadcast on each event.

    Wraps a :class:`BroadcastScanner` and repaints the terminal with ANSI
    escape sequences whenever a broadcast is found, lost or changes.
    """

    def __init__(
        self, device: bumble.device.Device, filter_duplicates: bool, sync_timeout: float
    ) -> None:
        scanner = BroadcastScanner(device, filter_duplicates, sync_timeout)
        # NOTE(review): BroadcastScanner itself never emits 'update' (only its
        # Broadcast objects do), so the last subscription appears to be inert.
        for event_name, handler in (
            ('new_broadcast', self.on_new_broadcast),
            ('broadcast_loss', self.on_broadcast_loss),
            ('update', self.refresh),
        ):
            scanner.on(event_name, handler)
        self.scanner = scanner
        self.status_message = ''

    async def start(self) -> None:
        self.status_message = color('Scanning...', 'green')
        await self.scanner.start()

    def on_new_broadcast(self, broadcast: BroadcastScanner.Broadcast) -> None:
        count = len(self.scanner.broadcasts)
        self.status_message = color(f'+Found {count} broadcasts', 'green')
        # Redraw whenever this broadcast's regular or periodic data changes.
        for event_name in ('change', 'update'):
            broadcast.on(event_name, self.refresh)
        self.refresh()

    def on_broadcast_loss(self, broadcast: BroadcastScanner.Broadcast) -> None:
        count = len(self.scanner.broadcasts)
        self.status_message = color(f'-Found {count} broadcasts', 'green')
        self.refresh()

    def refresh(self) -> None:
        """Repaint the status line and all known broadcasts from the top."""
        # Cursor home, erase to end of screen, cursor home again.
        for escape in ('\033[H', '\033[0J', '\033[H'):
            print(escape)

        # Status line and separator.
        print(self.status_message)
        print("==========================================")

        # One section per known broadcast.
        for known in self.scanner.broadcasts.values():
            known.print()
            print('------------------------------------------')

        # Erase whatever remains below the last printed line.
        print('\033[0J')
362
363
@contextlib.asynccontextmanager
async def create_device(transport: str) -> AsyncGenerator[bumble.device.Device, Any]:
    """Open *transport* and yield a powered-on device attached to it.

    The device gets the Auracast default name and address and a JSON keystore.
    The transport remains open, inside its own ``async with``, for as long as
    this context is active.
    """
    hci_transport = await bumble.transport.open_transport(transport)
    async with hci_transport as (source, sink):
        config = bumble.device.DeviceConfiguration(
            name=AURACAST_DEFAULT_DEVICE_NAME,
            address=AURACAST_DEFAULT_DEVICE_ADDRESS,
            keystore='JsonKeyStore',
        )
        device = bumble.device.Device.from_config_with_hci(config, source, sink)
        await device.power_on()
        yield device
384
385
async def find_broadcast_by_name(
    device: bumble.device.Device, name: Optional[str]
) -> BroadcastScanner.Broadcast:
    """Scan until a matching broadcast delivers a Basic Audio Announcement.

    Args:
        device: device to scan with.
        name: broadcast name to look for, or None to accept any broadcast.

    Returns:
        The first matching broadcast for which a Basic Audio Announcement
        has been received. Scanning is stopped before returning.
    """
    found: asyncio.Future = asyncio.get_running_loop().create_future()

    def on_new_broadcast(broadcast: BroadcastScanner.Broadcast) -> None:
        if name is not None and broadcast.name != name:
            print(color(f'Skipping broadcast {broadcast.name}'))
            return

        print(color('Broadcast found:', 'green'), broadcast.name)

        def on_change() -> None:
            # Resolve once, on the first change that carries the announcement.
            if found.done() or not broadcast.basic_audio_announcement:
                return
            print(color('Broadcast basic audio announcement received', 'green'))
            found.set_result(broadcast)

        broadcast.on('change', on_change)

    scanner = BroadcastScanner(device, False, AURACAST_DEFAULT_SYNC_TIMEOUT)
    scanner.on('new_broadcast', on_new_broadcast)
    await scanner.start()

    matching = await found
    await scanner.stop()
    return matching
412
413
async def run_scan(
    filter_duplicates: bool, sync_timeout: float, transport: str
) -> None:
    """Scan for broadcasts and keep displaying them until interrupted."""
    async with create_device(transport) as device:
        if device.supports_le_periodic_advertising:
            printer = PrintingBroadcastScanner(device, filter_duplicates, sync_timeout)
            await printer.start()
            # Wait on a future that is never resolved, i.e. run forever.
            await asyncio.get_running_loop().create_future()
        else:
            print(color('Periodic advertising not supported', 'red'))
425
426
async def _scan_for_broadcast(
    bass: 'bumble.profiles.bass.BroadcastAudioScanServiceProxy',
    device: 'bumble.device.Device',
    broadcast_name: Optional[str],
) -> 'Optional[BroadcastScanner.Broadcast]':
    """Scan for a broadcast on behalf of the BASS server.

    Tells the server a remote scan has started, then uses
    find_broadcast_by_name() to wait for a broadcast named *broadcast_name*
    (any broadcast when None).

    Returns:
        The broadcast, when it has a Broadcast Audio Announcement and a Basic
        Audio Announcement with at least one subgroup. Otherwise the server
        is told the remote scan has stopped and None is returned. This ensures
        every Remote Scan Started is matched by a Remote Scan Stopped.
    """
    await bass.remote_scan_started()
    if broadcast_name:
        print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
    else:
        print(color('Scanning for any broadcast', 'cyan'))
    broadcast = await find_broadcast_by_name(device, broadcast_name)

    if broadcast.broadcast_audio_announcement is None:
        print(color('No broadcast audio announcement found', 'red'))
        await bass.remote_scan_stopped()
        return None

    if (
        broadcast.basic_audio_announcement is None
        or not broadcast.basic_audio_announcement.subgroups
    ):
        print(color('No subgroups found', 'red'))
        await bass.remote_scan_stopped()
        return None

    return broadcast


async def run_assist(
    broadcast_name: Optional[str],
    source_id: Optional[int],
    command: str,
    transport: str,
    address: str,
) -> None:
    """Act as a BASS client (broadcast assistant) for the server at *address*.

    Connects, encrypts the link, requests AURACAST_DEFAULT_ATT_MTU, discovers
    the Broadcast Audio Scan Service, and subscribes to and prints every
    Broadcast Receive State. It then performs *command*:

      * ``monitor-state``: ``peer.sustain()`` while state updates are printed.
      * ``add-source``: scan for *broadcast_name* (any broadcast when None) on
        the server's behalf, add it as a source, and transfer the periodic
        advertising sync to the server.
      * ``modify-source``: scan likewise, then modify source *source_id*.
      * ``remove-source``: remove source *source_id*.
    """
    async with create_device(transport) as device:
        if not device.supports_le_periodic_advertising:
            print(color('Periodic advertising not supported', 'red'))
            return

        # Connect to the server
        print(f'=== Connecting to {address}...')
        connection = await device.connect(address)
        peer = bumble.device.Peer(connection)
        print(f'=== Connected to {peer}')

        print("+++ Encrypting connection...")
        await peer.connection.encrypt()
        print("+++ Connection encrypted")

        # Request a larger MTU
        mtu = AURACAST_DEFAULT_ATT_MTU
        print(color(f'$$$ Requesting MTU={mtu}', 'yellow'))
        await peer.request_mtu(mtu)

        # Get the BASS service
        bass = await peer.discover_service_and_create_proxy(
            bumble.profiles.bass.BroadcastAudioScanServiceProxy
        )

        # Check that the service was found
        if not bass:
            print(color('!!! Broadcast Audio Scan Service not found', 'red'))
            return

        # Subscribe to and read the broadcast receive state characteristics.
        # A failed subscription is reported but the initial value is still read.
        for i, broadcast_receive_state in enumerate(bass.broadcast_receive_states):
            try:
                await broadcast_receive_state.subscribe(
                    lambda value, i=i: print(
                        f"{color(f'Broadcast Receive State Update [{i}]:', 'green')} {value}"
                    )
                )
            except bumble.core.ProtocolError as error:
                print(
                    color(
                        '!!! Failed to subscribe to Broadcast Receive State characteristic:',
                        'red',
                    ),
                    error,
                )
            value = await broadcast_receive_state.read_value()
            print(
                f'{color(f"Initial Broadcast Receive State [{i}]:", "green")} {value}'
            )

        if command == 'monitor-state':
            await peer.sustain()
            return

        if command == 'add-source':
            # Find the requested broadcast
            broadcast = await _scan_for_broadcast(bass, device, broadcast_name)
            if broadcast is None:
                return
            # Type narrowing only; _scan_for_broadcast has checked both.
            assert broadcast.broadcast_audio_announcement is not None
            assert broadcast.basic_audio_announcement is not None

            # Add the source
            print(color('Adding source:', 'blue'), broadcast.sync.advertiser_address)
            await bass.add_source(
                broadcast.sync.advertiser_address,
                broadcast.sync.sid,
                broadcast.broadcast_audio_announcement.broadcast_id,
                bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_AVAILABLE,
                0xFFFF,
                [
                    bumble.profiles.bass.SubgroupInfo(
                        bumble.profiles.bass.SubgroupInfo.ANY_BIS,
                        bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
                    )
                ],
            )

            # Initiate a PA Sync Transfer
            await broadcast.sync.transfer(peer.connection)

            # Notify the sink that we're done scanning.
            await bass.remote_scan_stopped()

            await peer.sustain()
            return

        if command == 'modify-source':
            if source_id is None:
                print(color('!!! modify-source requires --source-id'))
                return

            # Find the requested broadcast
            broadcast = await _scan_for_broadcast(bass, device, broadcast_name)
            if broadcast is None:
                return
            # Type narrowing only; _scan_for_broadcast has checked this.
            assert broadcast.basic_audio_announcement is not None

            # Modify the source
            print(
                color('Modifying source:', 'blue'),
                source_id,
            )
            await bass.modify_source(
                source_id,
                bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
                0xFFFF,
                [
                    bumble.profiles.bass.SubgroupInfo(
                        bumble.profiles.bass.SubgroupInfo.ANY_BIS,
                        bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
                    )
                ],
            )

            # Notify the sink that we're done scanning (previously missing
            # here, which left the server's remote scan marked as started).
            await bass.remote_scan_stopped()

            await peer.sustain()
            return

        if command == 'remove-source':
            if source_id is None:
                print(color('!!! remove-source requires --source-id'))
                return

            # Remove the source
            print(color('Removing source:', 'blue'), source_id)
            await bass.remove_source(source_id)
            await peer.sustain()
            return

        print(color(f'!!! invalid command {command}'))
589
590
async def run_pair(transport: str, address: str) -> None:
    """Connect to the device at *address* over GATT and pair with it."""
    async with create_device(transport) as device:
        # Connect to the server
        print(f'=== Connecting to {address}...')
        async with device.connect_as_gatt(address) as peer:
            print(f'=== Connected to {peer}')
            connection = peer.connection
            print("+++ Initiating pairing...")
            await connection.pair()
            print("+++ Paired")
602
603
def run_async(async_command: Coroutine) -> None:
    """Run *async_command* with asyncio.run(), reporting protocol errors.

    A ``bumble.core.ProtocolError`` escaping the command is printed instead
    of propagating. An ATT error whose code is a BASS ``ApplicationError`` is
    shown by that enum member's name. Any other protocol error is shown with
    ``str(error)``.
    """
    try:
        asyncio.run(async_command)
    except bumble.core.ProtocolError as error:
        is_bass_error = error.error_namespace == 'att' and error.error_code in list(
            bumble.profiles.bass.ApplicationError
        )
        message = (
            bumble.profiles.bass.ApplicationError(error.error_code).name
            if is_bass_error
            else str(error)
        )
        print(
            color('!!! An error occurred while executing the command:', 'red'),
            message,
        )
618
619
620# -----------------------------------------------------------------------------
621# Main
622# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
def auracast(ctx):
    # No docstring on purpose: click would turn it into the group's help text.
    # Make sure ctx.obj is initialized to a dict.
    ctx.ensure_object(dict)
629
630
@auracast.command('scan')
@click.option(
    '--filter-duplicates',
    is_flag=True,
    default=False,
    help='Filter duplicates',
)
@click.option(
    '--sync-timeout',
    metavar='SYNC_TIMEOUT',
    type=float,
    default=AURACAST_DEFAULT_SYNC_TIMEOUT,
    help='Sync timeout (in seconds)',
)
@click.argument('transport')
@click.pass_context
def scan(ctx, filter_duplicates, sync_timeout, transport):
    """Scan for public broadcasts"""
    # ctx is unused; it is accepted because of @click.pass_context.
    run_async(
        run_scan(
            filter_duplicates=filter_duplicates,
            sync_timeout=sync_timeout,
            transport=transport,
        )
    )
647
648
@auracast.command('assist')
@click.option(
    '--broadcast-name',
    metavar='BROADCAST_NAME',
    help='Broadcast Name to tune to (default: any broadcast)',
)
@click.option(
    '--source-id',
    metavar='SOURCE_ID',
    type=int,
    # Both modify-source and remove-source refuse to run without it.
    help='Source ID (for modify-source and remove-source commands)',
)
@click.option(
    '--command',
    type=click.Choice(
        ['monitor-state', 'add-source', 'modify-source', 'remove-source']
    ),
    required=True,
)
@click.argument('transport')
@click.argument('address')
@click.pass_context
def assist(ctx, broadcast_name, source_id, command, transport, address):
    """Scan for broadcasts on behalf of an audio server"""
    run_async(run_assist(broadcast_name, source_id, command, transport, address))
674
675
@auracast.command('pair')
@click.argument('transport')
@click.argument('address')
@click.pass_context
def pair(ctx, transport, address):
    """Pair with an audio server"""
    # ctx is unused; it is accepted because of @click.pass_context.
    run_async(run_pair(transport=transport, address=address))
683
684
def main():
    """Configure logging from BUMBLE_LOGLEVEL, then run the auracast CLI."""
    # BUMBLE_LOGLEVEL defaults to INFO and is case-insensitive.
    log_level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()
    logging.basicConfig(level=log_level)
    auracast()
688
689
690# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point; click supplies the group's arguments from sys.argv.
    main()  # pylint: disable=no-value-for-parameter
693