1# Copyright 2024 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13
14"""LE Audio - Broadcast Audio Scan Service"""
15
16# -----------------------------------------------------------------------------
17# Imports
18# -----------------------------------------------------------------------------
19from __future__ import annotations
20import dataclasses
21import logging
22import struct
23from typing import ClassVar, List, Optional, Sequence
24
25from bumble import core
26from bumble import device
27from bumble import gatt
28from bumble import gatt_client
29from bumble import hci
30from bumble import utils
31
32# -----------------------------------------------------------------------------
33# Logging
34# -----------------------------------------------------------------------------
35logger = logging.getLogger(__name__)
36
37
38# -----------------------------------------------------------------------------
39# Constants
40# -----------------------------------------------------------------------------
class ApplicationError(utils.OpenIntEnum):
    """Application error codes declared by this module.

    NOTE(review): presumably these are the error codes a server reports when a
    control point write is rejected -- nothing in this file uses them yet, so
    confirm against the service specification.
    """

    # Named for a control point op code that is not supported.
    OPCODE_NOT_SUPPORTED = 0x80
    # Named for a source ID that is not valid.
    INVALID_SOURCE_ID = 0x81
44
45
46# -----------------------------------------------------------------------------
def encode_subgroups(subgroups: Sequence[SubgroupInfo]) -> bytes:
    """Serialize a sequence of subgroups.

    The output starts with one byte holding the number of subgroups. Each
    subgroup then contributes its `bis_sync` as a little-endian 32-bit value,
    one byte holding the metadata length, and the metadata bytes themselves.

    Args:
        subgroups: the subgroups to serialize, in order.

    Returns:
        The serialized subgroups.
    """
    chunks = [bytes([len(subgroups)])]
    for entry in subgroups:
        # Fixed-size part first (bis_sync + metadata length), then the metadata.
        chunks.append(struct.pack("<IB", entry.bis_sync, len(entry.metadata)))
        chunks.append(entry.metadata)
    return b"".join(chunks)
53
54
def decode_subgroups(data: bytes) -> List[SubgroupInfo]:
    """Parse serialized subgroups (the inverse of `encode_subgroups`).

    The first byte of `data` is the number of subgroups. Each subgroup is read
    as a little-endian 32-bit `bis_sync`, a one-byte metadata length, and that
    many bytes of metadata.

    Args:
        data: the serialized subgroups, starting at the count byte.

    Returns:
        The parsed subgroups, in the order they appear in `data`.
    """
    count = data[0]
    cursor = 1  # Position of the subgroup currently being parsed
    parsed: List[SubgroupInfo] = []
    while len(parsed) < count:
        (bis_sync,) = struct.unpack("<I", data[cursor : cursor + 4])
        metadata_start = cursor + 5
        metadata_end = metadata_start + data[cursor + 4]
        parsed.append(SubgroupInfo(bis_sync, data[metadata_start:metadata_end]))
        # The next subgroup starts right after this subgroup's metadata.
        cursor = metadata_end

    return parsed
67
68
69# -----------------------------------------------------------------------------
class PeriodicAdvertisingSyncParams(utils.OpenIntEnum):
    """Values of the `pa_sync` field of the Add Source and Modify Source operations.

    NOTE(review): "PA" presumably stands for Periodic Advertising and "PAST"
    for Periodic Advertising Sync Transfer -- confirm against the specification.
    """

    DO_NOT_SYNCHRONIZE_TO_PA = 0x00
    SYNCHRONIZE_TO_PA_PAST_AVAILABLE = 0x01
    SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE = 0x02
74
75
@dataclasses.dataclass
class SubgroupInfo:
    """One subgroup entry, as handled by `encode_subgroups`/`decode_subgroups`."""

    # Special `bis_sync` value with all 32 bits set.
    # NOTE(review): presumably means "no preference / any BIS" -- confirm.
    ANY_BIS: ClassVar[int] = 0xFFFFFFFF

    # Serialized as a little-endian unsigned 32-bit value.
    bis_sync: int
    # Raw metadata bytes; serialized with a one-byte length prefix.
    metadata: bytes
82
83
class ControlPointOperation:
    """Base class for Broadcast Audio Scan control point operations.

    An operation is an op code and raw parameter bytes. Its serialized form is
    the op code byte followed by the parameters.
    """

    class OpCode(utils.OpenIntEnum):
        REMOTE_SCAN_STOPPED = 0x00
        REMOTE_SCAN_STARTED = 0x01
        ADD_SOURCE = 0x02
        MODIFY_SOURCE = 0x03
        SET_BROADCAST_CODE = 0x04
        REMOVE_SOURCE = 0x05

    op_code: OpCode
    parameters: bytes

    def __init__(self, op_code: OpCode, parameters: bytes = b"") -> None:
        """Store the op code and the already-serialized parameters."""
        self.op_code = op_code
        self.parameters = parameters

    @classmethod
    def from_bytes(cls, data: bytes) -> ControlPointOperation:
        """Parse a serialized operation into the matching subclass instance.

        Args:
            data: the op code byte followed by the operation parameters.

        Returns:
            An instance of the operation subclass selected by the op code.

        Raises:
            core.InvalidArgumentError: if the op code is not one of `OpCode`.
        """
        op_code = data[0]
        op_codes = cls.OpCode

        # Operations that carry no parameters are built directly.
        without_parameters = {
            op_codes.REMOTE_SCAN_STOPPED: RemoteScanStoppedOperation,
            op_codes.REMOTE_SCAN_STARTED: RemoteScanStartedOperation,
        }
        if (operation_class := without_parameters.get(op_code)) is not None:
            return operation_class()

        # The other operations parse whatever follows the op code byte.
        with_parameters = {
            op_codes.ADD_SOURCE: AddSourceOperation,
            op_codes.MODIFY_SOURCE: ModifySourceOperation,
            op_codes.SET_BROADCAST_CODE: SetBroadcastCodeOperation,
            op_codes.REMOVE_SOURCE: RemoveSourceOperation,
        }
        if (operation_class := with_parameters.get(op_code)) is not None:
            return operation_class.from_parameters(data[1:])

        raise core.InvalidArgumentError("invalid op code")

    def __bytes__(self) -> bytes:
        """Return the op code byte followed by the parameters."""
        op_code_byte = bytes([self.op_code])
        return op_code_byte + self.parameters
126
127
class RemoteScanStoppedOperation(ControlPointOperation):
    """Operation with the REMOTE_SCAN_STOPPED op code and no parameters."""

    def __init__(self) -> None:
        stopped = ControlPointOperation.OpCode.REMOTE_SCAN_STOPPED
        super().__init__(op_code=stopped)
131
132
class RemoteScanStartedOperation(ControlPointOperation):
    """Operation with the REMOTE_SCAN_STARTED op code and no parameters."""

    def __init__(self) -> None:
        started = ControlPointOperation.OpCode.REMOTE_SCAN_STARTED
        super().__init__(op_code=started)
136
137
class AddSourceOperation(ControlPointOperation):
    """ADD_SOURCE control point operation.

    Parameters as packed by `__init__`: address type (1 byte), address
    (6 bytes), advertising SID (1 byte), broadcast ID (3 bytes, little-endian),
    PA sync (1 byte), PA interval (2 bytes, little-endian), then the encoded
    subgroups.
    """

    def __init__(
        self,
        advertiser_address: hci.Address,
        advertising_sid: int,
        broadcast_id: int,
        pa_sync: PeriodicAdvertisingSyncParams,
        pa_interval: int,
        subgroups: Sequence[SubgroupInfo],
    ) -> None:
        """Build the operation and its serialized parameters from field values."""
        fixed_fields = struct.pack(
            "<B6sB3sBH",
            advertiser_address.address_type,
            bytes(advertiser_address),
            advertising_sid,
            broadcast_id.to_bytes(3, "little"),
            pa_sync,
            pa_interval,
        )
        super().__init__(
            ControlPointOperation.OpCode.ADD_SOURCE,
            fixed_fields + encode_subgroups(subgroups),
        )
        self.advertiser_address = advertiser_address
        self.advertising_sid = advertising_sid
        self.broadcast_id = broadcast_id
        self.pa_sync = pa_sync
        self.pa_interval = pa_interval
        self.subgroups = list(subgroups)

    @classmethod
    def from_parameters(cls, parameters: bytes) -> AddSourceOperation:
        """Parse the parameters of an ADD_SOURCE operation.

        Args:
            parameters: the operation bytes that follow the op code.

        Returns:
            An operation whose `parameters` attribute is the given bytes and
            whose fields are decoded from them.
        """
        # __init__ is bypassed so the received bytes are kept as-is.
        operation = cls.__new__(cls)
        operation.op_code = ControlPointOperation.OpCode.ADD_SOURCE
        operation.parameters = parameters
        _, operation.advertiser_address = (
            hci.Address.parse_address_preceded_by_type(parameters, 1)
        )
        operation.advertising_sid = parameters[7]
        operation.broadcast_id = int.from_bytes(parameters[8:11], "little")
        operation.pa_sync = PeriodicAdvertisingSyncParams(parameters[11])
        (operation.pa_interval,) = struct.unpack("<H", parameters[12:14])
        operation.subgroups = decode_subgroups(parameters[14:])
        return operation
182
183
class ModifySourceOperation(ControlPointOperation):
    """MODIFY_SOURCE control point operation.

    Parameters: source ID (1 byte), PA sync (1 byte), PA interval (2 bytes,
    little-endian), then the encoded subgroups.
    """

    def __init__(
        self,
        source_id: int,
        pa_sync: PeriodicAdvertisingSyncParams,
        pa_interval: int,
        subgroups: Sequence[SubgroupInfo],
    ) -> None:
        """Build the operation and its serialized parameters from field values."""
        fixed_fields = struct.pack("<BBH", source_id, pa_sync, pa_interval)
        super().__init__(
            ControlPointOperation.OpCode.MODIFY_SOURCE,
            fixed_fields + encode_subgroups(subgroups),
        )
        self.source_id = source_id
        self.pa_sync = pa_sync
        self.pa_interval = pa_interval
        self.subgroups = list(subgroups)

    @classmethod
    def from_parameters(cls, parameters: bytes) -> ModifySourceOperation:
        """Parse the parameters of a MODIFY_SOURCE operation.

        Args:
            parameters: the operation bytes that follow the op code.

        Returns:
            An operation whose `parameters` attribute is the given bytes and
            whose fields are decoded from them.
        """
        # __init__ is bypassed so the received bytes are kept as-is.
        operation = cls.__new__(cls)
        operation.op_code = ControlPointOperation.OpCode.MODIFY_SOURCE
        operation.parameters = parameters
        operation.source_id = parameters[0]
        operation.pa_sync = PeriodicAdvertisingSyncParams(parameters[1])
        (operation.pa_interval,) = struct.unpack("<H", parameters[2:4])
        operation.subgroups = decode_subgroups(parameters[4:])
        return operation
212
213
class SetBroadcastCodeOperation(ControlPointOperation):
    """SET_BROADCAST_CODE control point operation.

    Parameters: source ID (1 byte) followed by the broadcast code.
    """

    def __init__(
        self,
        source_id: int,
        broadcast_code: bytes,
    ) -> None:
        """Build the operation from a source ID and a 16-byte broadcast code.

        Raises:
            core.InvalidArgumentError: if `broadcast_code` is not 16 bytes long.
        """
        payload = bytes([source_id]) + broadcast_code
        super().__init__(ControlPointOperation.OpCode.SET_BROADCAST_CODE, payload)
        self.source_id = source_id
        self.broadcast_code = broadcast_code

        # The length is checked only once the fields have been stored.
        if len(broadcast_code) != 16:
            raise core.InvalidArgumentError("broadcast_code must be 16 bytes")

    @classmethod
    def from_parameters(cls, parameters: bytes) -> SetBroadcastCodeOperation:
        """Parse the parameters of a SET_BROADCAST_CODE operation.

        Args:
            parameters: the operation bytes that follow the op code.

        Returns:
            An operation whose source ID is the first parameter byte and whose
            broadcast code is the slice `parameters[1:17]`.
        """
        # __init__ is bypassed so the received bytes are kept as-is.
        operation = cls.__new__(cls)
        operation.op_code = ControlPointOperation.OpCode.SET_BROADCAST_CODE
        operation.parameters = parameters
        operation.source_id, operation.broadcast_code = (
            parameters[0],
            parameters[1:17],
        )
        return operation
238
239
class RemoveSourceOperation(ControlPointOperation):
    """REMOVE_SOURCE control point operation; its only parameter is a source ID."""

    def __init__(self, source_id: int) -> None:
        """Build the operation for the given source ID (serialized as one byte)."""
        payload = bytes([source_id])
        super().__init__(ControlPointOperation.OpCode.REMOVE_SOURCE, payload)
        self.source_id = source_id

    @classmethod
    def from_parameters(cls, parameters: bytes) -> RemoveSourceOperation:
        """Parse the parameters of a REMOVE_SOURCE operation.

        Args:
            parameters: the operation bytes that follow the op code.

        Returns:
            An operation whose source ID is the first parameter byte.
        """
        # __init__ is bypassed so the received bytes are kept as-is.
        operation = cls.__new__(cls)
        operation.op_code = ControlPointOperation.OpCode.REMOVE_SOURCE
        operation.parameters = parameters
        operation.source_id = parameters[0]
        return operation
252
253
@dataclasses.dataclass
class BroadcastReceiveState:
    """Value of a Broadcast Receive State characteristic.

    Serialized layout (see `__bytes__`): source ID (1 byte), address type
    (1 byte), address (6 bytes), advertising SID (1 byte), broadcast ID
    (3 bytes, little-endian), PA sync state (1 byte), BIG encryption (1 byte),
    the bad code bytes, then the encoded subgroups.
    """

    class PeriodicAdvertisingSyncState(utils.OpenIntEnum):
        NOT_SYNCHRONIZED_TO_PA = 0x00
        SYNCINFO_REQUEST = 0x01
        SYNCHRONIZED_TO_PA = 0x02
        FAILED_TO_SYNCHRONIZE_TO_PA = 0x03
        NO_PAST = 0x04

    class BigEncryption(utils.OpenIntEnum):
        NOT_ENCRYPTED = 0x00
        BROADCAST_CODE_REQUIRED = 0x01
        DECRYPTING = 0x02
        BAD_CODE = 0x03

    source_id: int
    source_address: hci.Address
    source_adv_sid: int
    broadcast_id: int
    pa_sync_state: PeriodicAdvertisingSyncState
    big_encryption: BigEncryption
    bad_code: bytes
    subgroups: List[SubgroupInfo]

    @classmethod
    def from_bytes(cls, data: bytes) -> Optional[BroadcastReceiveState]:
        """Parse a serialized receive state.

        Args:
            data: the serialized state; may be empty.

        Returns:
            The parsed state, or None when `data` is empty.
        """
        if not data:
            return None

        source_id = data[0]
        _, source_address = hci.Address.parse_address_preceded_by_type(data, 2)
        source_adv_sid = data[8]
        broadcast_id = int.from_bytes(data[9:12], "little")
        pa_sync_state = cls.PeriodicAdvertisingSyncState(data[12])
        big_encryption = cls.BigEncryption(data[13])

        # A 16-byte bad code precedes the subgroups only in the BAD_CODE state.
        if big_encryption == cls.BigEncryption.BAD_CODE:
            bad_code, subgroups_offset = data[14:30], 30
        else:
            bad_code, subgroups_offset = b"", 14
        subgroups = decode_subgroups(data[subgroups_offset:])

        return cls(
            source_id=source_id,
            source_address=source_address,
            source_adv_sid=source_adv_sid,
            broadcast_id=broadcast_id,
            pa_sync_state=pa_sync_state,
            big_encryption=big_encryption,
            bad_code=bad_code,
            subgroups=subgroups,
        )

    def __bytes__(self) -> bytes:
        """Serialize the state: fixed fields, bad code bytes, then subgroups."""
        fixed_fields = struct.pack(
            "<BB6sB3sBB",
            self.source_id,
            self.source_address.address_type,
            bytes(self.source_address),
            self.source_adv_sid,
            self.broadcast_id.to_bytes(3, "little"),
            self.pa_sync_state,
            self.big_encryption,
        )
        return fixed_fields + self.bad_code + encode_subgroups(self.subgroups)
322
323
324# -----------------------------------------------------------------------------
class BroadcastAudioScanService(gatt.TemplateService):
    """Server-side GATT template for the Broadcast Audio Scan Service.

    The service exposes two characteristics:
      * a writable Broadcast Audio Scan Control Point, whose writes are passed
        to `on_broadcast_audio_scan_control_point_write`;
      * a readable, notifying Broadcast Receive State that requires encryption
        to read.
    """

    UUID = gatt.GATT_BROADCAST_AUDIO_SCAN_SERVICE

    def __init__(self) -> None:
        """Create both characteristics and register them with the service."""
        self.broadcast_audio_scan_control_point_characteristic = gatt.Characteristic(
            gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC,
            gatt.Characteristic.Properties.WRITE
            | gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
            gatt.Characteristic.WRITEABLE,
            gatt.CharacteristicValue(
                write=self.on_broadcast_audio_scan_control_point_write
            ),
        )

        self.broadcast_receive_state_characteristic = gatt.Characteristic(
            gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC,
            gatt.Characteristic.Properties.READ | gatt.Characteristic.Properties.NOTIFY,
            gatt.Characteristic.Permissions.READABLE
            | gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
            b"12",  # TEST
        )

        # Register the characteristics created above. The previous code passed
        # `self.battery_level_characteristic`, an attribute this class never
        # defines, so constructing the service raised AttributeError.
        super().__init__(
            [
                self.broadcast_audio_scan_control_point_characteristic,
                self.broadcast_receive_state_characteristic,
            ]
        )

    def on_broadcast_audio_scan_control_point_write(
        self, connection: device.Connection, value: bytes
    ) -> None:
        """Handle a write to the control point characteristic.

        Currently a no-op: the written value is ignored.

        Args:
            connection: the connection the write was received on.
            value: the bytes written to the characteristic.
        """
        pass
353
354
355# -----------------------------------------------------------------------------
class BroadcastAudioScanServiceProxy(gatt_client.ProfileServiceProxy):
    """Client-side proxy for a remote Broadcast Audio Scan Service.

    It gives access to the remote control point characteristic and to every
    Broadcast Receive State characteristic (decoded as `BroadcastReceiveState`),
    with one helper method per control point operation.
    """

    SERVICE_CLASS = BroadcastAudioScanService

    broadcast_audio_scan_control_point: gatt_client.CharacteristicProxy
    broadcast_receive_states: List[gatt.DelegatedCharacteristicAdapter]

    def __init__(self, service_proxy: gatt_client.ServiceProxy):
        """Look up the required characteristics on the remote service.

        Args:
            service_proxy: proxy for the discovered remote service.

        Raises:
            gatt.InvalidServiceError: if the control point characteristic or
                the receive state characteristic is not found.
        """
        self.service_proxy = service_proxy

        if not (
            characteristics := service_proxy.get_characteristics_by_uuid(
                gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC
            )
        ):
            raise gatt.InvalidServiceError(
                "Broadcast Audio Scan Control Point characteristic not found"
            )
        self.broadcast_audio_scan_control_point = characteristics[0]

        if not (
            characteristics := service_proxy.get_characteristics_by_uuid(
                gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC
            )
        ):
            raise gatt.InvalidServiceError(
                "Broadcast Receive State characteristic not found"
            )
        # Every receive state characteristic found is wrapped, not only the first.
        self.broadcast_receive_states = [
            gatt.DelegatedCharacteristicAdapter(
                characteristic, decode=BroadcastReceiveState.from_bytes
            )
            for characteristic in characteristics
        ]

    async def send_control_point_operation(
        self, operation: ControlPointOperation
    ) -> None:
        """Serialize `operation` and write it, with response, to the control point."""
        await self.broadcast_audio_scan_control_point.write_value(
            bytes(operation), with_response=True
        )

    async def remote_scan_started(self) -> None:
        """Send a REMOTE_SCAN_STARTED operation."""
        await self.send_control_point_operation(RemoteScanStartedOperation())

    async def remote_scan_stopped(self) -> None:
        """Send a REMOTE_SCAN_STOPPED operation."""
        await self.send_control_point_operation(RemoteScanStoppedOperation())

    async def add_source(
        self,
        advertiser_address: hci.Address,
        advertising_sid: int,
        broadcast_id: int,
        pa_sync: PeriodicAdvertisingSyncParams,
        pa_interval: int,
        subgroups: Sequence[SubgroupInfo],
    ) -> None:
        """Send an ADD_SOURCE operation built from the given fields."""
        await self.send_control_point_operation(
            AddSourceOperation(
                advertiser_address,
                advertising_sid,
                broadcast_id,
                pa_sync,
                pa_interval,
                subgroups,
            )
        )

    async def modify_source(
        self,
        source_id: int,
        pa_sync: PeriodicAdvertisingSyncParams,
        pa_interval: int,
        subgroups: Sequence[SubgroupInfo],
    ) -> None:
        """Send a MODIFY_SOURCE operation built from the given fields."""
        await self.send_control_point_operation(
            ModifySourceOperation(
                source_id,
                pa_sync,
                pa_interval,
                subgroups,
            )
        )

    async def set_broadcast_code(
        self,
        source_id: int,
        broadcast_code: bytes,
    ) -> None:
        """Send a SET_BROADCAST_CODE operation.

        Args:
            source_id: ID of the source the code applies to.
            broadcast_code: the broadcast code; must be 16 bytes long.

        Raises:
            core.InvalidArgumentError: if `broadcast_code` is not 16 bytes long
                (raised by `SetBroadcastCodeOperation`).
        """
        await self.send_control_point_operation(
            SetBroadcastCodeOperation(source_id, broadcast_code)
        )

    async def remove_source(self, source_id: int) -> None:
        """Send a REMOVE_SOURCE operation for the given source ID."""
        await self.send_control_point_operation(RemoveSourceOperation(source_id))
441