1# Copyright 2024 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for 13 14"""LE Audio - Broadcast Audio Scan Service""" 15 16# ----------------------------------------------------------------------------- 17# Imports 18# ----------------------------------------------------------------------------- 19from __future__ import annotations 20import dataclasses 21import logging 22import struct 23from typing import ClassVar, List, Optional, Sequence 24 25from bumble import core 26from bumble import device 27from bumble import gatt 28from bumble import gatt_client 29from bumble import hci 30from bumble import utils 31 32# ----------------------------------------------------------------------------- 33# Logging 34# ----------------------------------------------------------------------------- 35logger = logging.getLogger(__name__) 36 37 38# ----------------------------------------------------------------------------- 39# Constants 40# ----------------------------------------------------------------------------- 41class ApplicationError(utils.OpenIntEnum): 42 OPCODE_NOT_SUPPORTED = 0x80 43 INVALID_SOURCE_ID = 0x81 44 45 46# ----------------------------------------------------------------------------- 47def encode_subgroups(subgroups: Sequence[SubgroupInfo]) -> bytes: 48 return bytes([len(subgroups)]) + b"".join( 49 struct.pack("<IB", subgroup.bis_sync, len(subgroup.metadata)) 50 + subgroup.metadata 51 for subgroup in subgroups 52 ) 53 54 55def decode_subgroups(data: bytes) -> List[SubgroupInfo]: 56 num_subgroups = data[0] 57 offset = 1 58 subgroups 
= [] 59 for _ in range(num_subgroups): 60 bis_sync = struct.unpack("<I", data[offset : offset + 4])[0] 61 metadata_length = data[offset + 4] 62 metadata = data[offset + 5 : offset + 5 + metadata_length] 63 offset += 5 + metadata_length 64 subgroups.append(SubgroupInfo(bis_sync, metadata)) 65 66 return subgroups 67 68 69# ----------------------------------------------------------------------------- 70class PeriodicAdvertisingSyncParams(utils.OpenIntEnum): 71 DO_NOT_SYNCHRONIZE_TO_PA = 0x00 72 SYNCHRONIZE_TO_PA_PAST_AVAILABLE = 0x01 73 SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE = 0x02 74 75 76@dataclasses.dataclass 77class SubgroupInfo: 78 ANY_BIS: ClassVar[int] = 0xFFFFFFFF 79 80 bis_sync: int 81 metadata: bytes 82 83 84class ControlPointOperation: 85 class OpCode(utils.OpenIntEnum): 86 REMOTE_SCAN_STOPPED = 0x00 87 REMOTE_SCAN_STARTED = 0x01 88 ADD_SOURCE = 0x02 89 MODIFY_SOURCE = 0x03 90 SET_BROADCAST_CODE = 0x04 91 REMOVE_SOURCE = 0x05 92 93 op_code: OpCode 94 parameters: bytes 95 96 @classmethod 97 def from_bytes(cls, data: bytes) -> ControlPointOperation: 98 op_code = data[0] 99 100 if op_code == cls.OpCode.REMOTE_SCAN_STOPPED: 101 return RemoteScanStoppedOperation() 102 103 if op_code == cls.OpCode.REMOTE_SCAN_STARTED: 104 return RemoteScanStartedOperation() 105 106 if op_code == cls.OpCode.ADD_SOURCE: 107 return AddSourceOperation.from_parameters(data[1:]) 108 109 if op_code == cls.OpCode.MODIFY_SOURCE: 110 return ModifySourceOperation.from_parameters(data[1:]) 111 112 if op_code == cls.OpCode.SET_BROADCAST_CODE: 113 return SetBroadcastCodeOperation.from_parameters(data[1:]) 114 115 if op_code == cls.OpCode.REMOVE_SOURCE: 116 return RemoveSourceOperation.from_parameters(data[1:]) 117 118 raise core.InvalidArgumentError("invalid op code") 119 120 def __init__(self, op_code: OpCode, parameters: bytes = b"") -> None: 121 self.op_code = op_code 122 self.parameters = parameters 123 124 def __bytes__(self) -> bytes: 125 return bytes([self.op_code]) + self.parameters 126 
class RemoteScanStoppedOperation(ControlPointOperation):
    """Remote Scan Stopped operation; carries no parameters."""

    def __init__(self) -> None:
        super().__init__(ControlPointOperation.OpCode.REMOTE_SCAN_STOPPED)


class RemoteScanStartedOperation(ControlPointOperation):
    """Remote Scan Started operation; carries no parameters."""

    def __init__(self) -> None:
        super().__init__(ControlPointOperation.OpCode.REMOTE_SCAN_STARTED)


class AddSourceOperation(ControlPointOperation):
    """Add Source operation.

    Parameter layout: Advertiser_Address_Type (1), Advertiser_Address (6),
    Advertising_SID (1), Broadcast_ID (3, little-endian), PA_Sync (1),
    PA_Interval (2, little-endian), then the subgroup list.
    """

    # Size of the fields that precede the subgroup list.
    _FIXED_PARAMETERS_SIZE = 14

    @classmethod
    def from_parameters(cls, parameters: bytes) -> AddSourceOperation:
        """Build an instance from the parameter bytes (opcode excluded).

        Raises:
            core.InvalidArgumentError: if `parameters` is shorter than the
                fixed-size fields.
        """
        if len(parameters) < cls._FIXED_PARAMETERS_SIZE:
            raise core.InvalidArgumentError("Add Source parameters too short")

        instance = cls.__new__(cls)
        instance.op_code = ControlPointOperation.OpCode.ADD_SOURCE
        instance.parameters = parameters
        # Address type is at offset 0, the 6-byte address starts at offset 1.
        instance.advertiser_address = hci.Address.parse_address_preceded_by_type(
            parameters, 1
        )[1]
        instance.advertising_sid = parameters[7]
        instance.broadcast_id = int.from_bytes(parameters[8:11], "little")
        instance.pa_sync = PeriodicAdvertisingSyncParams(parameters[11])
        instance.pa_interval = struct.unpack("<H", parameters[12:14])[0]
        instance.subgroups = decode_subgroups(parameters[14:])
        return instance

    def __init__(
        self,
        advertiser_address: hci.Address,
        advertising_sid: int,
        broadcast_id: int,
        pa_sync: PeriodicAdvertisingSyncParams,
        pa_interval: int,
        subgroups: Sequence[SubgroupInfo],
    ) -> None:
        super().__init__(
            ControlPointOperation.OpCode.ADD_SOURCE,
            struct.pack(
                "<B6sB3sBH",
                advertiser_address.address_type,
                bytes(advertiser_address),
                advertising_sid,
                broadcast_id.to_bytes(3, "little"),
                pa_sync,
                pa_interval,
            )
            + encode_subgroups(subgroups),
        )
        self.advertiser_address = advertiser_address
        self.advertising_sid = advertising_sid
        self.broadcast_id = broadcast_id
        self.pa_sync = pa_sync
        self.pa_interval = pa_interval
        self.subgroups = list(subgroups)


class ModifySourceOperation(ControlPointOperation):
    """Modify Source operation.

    Parameter layout: Source_ID (1), PA_Sync (1), PA_Interval (2,
    little-endian), then the subgroup list.
    """

    # Size of the fields that precede the subgroup list.
    _FIXED_PARAMETERS_SIZE = 4

    @classmethod
    def from_parameters(cls, parameters: bytes) -> ModifySourceOperation:
        """Build an instance from the parameter bytes (opcode excluded).

        Raises:
            core.InvalidArgumentError: if `parameters` is shorter than the
                fixed-size fields.
        """
        if len(parameters) < cls._FIXED_PARAMETERS_SIZE:
            raise core.InvalidArgumentError("Modify Source parameters too short")

        instance = cls.__new__(cls)
        instance.op_code = ControlPointOperation.OpCode.MODIFY_SOURCE
        instance.parameters = parameters
        instance.source_id = parameters[0]
        instance.pa_sync = PeriodicAdvertisingSyncParams(parameters[1])
        instance.pa_interval = struct.unpack("<H", parameters[2:4])[0]
        instance.subgroups = decode_subgroups(parameters[4:])
        return instance

    def __init__(
        self,
        source_id: int,
        pa_sync: PeriodicAdvertisingSyncParams,
        pa_interval: int,
        subgroups: Sequence[SubgroupInfo],
    ) -> None:
        super().__init__(
            ControlPointOperation.OpCode.MODIFY_SOURCE,
            struct.pack("<BBH", source_id, pa_sync, pa_interval)
            + encode_subgroups(subgroups),
        )
        self.source_id = source_id
        self.pa_sync = pa_sync
        self.pa_interval = pa_interval
        self.subgroups = list(subgroups)


class SetBroadcastCodeOperation(ControlPointOperation):
    """Set Broadcast Code operation: Source_ID (1) + Broadcast_Code (16)."""

    _BROADCAST_CODE_SIZE = 16

    @classmethod
    def from_parameters(cls, parameters: bytes) -> SetBroadcastCodeOperation:
        """Build an instance from the parameter bytes (opcode excluded).

        Raises:
            core.InvalidArgumentError: if fewer than 17 bytes are supplied
                (previously a short code was silently accepted).
        """
        if len(parameters) < 1 + cls._BROADCAST_CODE_SIZE:
            raise core.InvalidArgumentError(
                "Set Broadcast Code parameters too short"
            )

        instance = cls.__new__(cls)
        instance.op_code = ControlPointOperation.OpCode.SET_BROADCAST_CODE
        instance.parameters = parameters
        instance.source_id = parameters[0]
        instance.broadcast_code = parameters[1 : 1 + cls._BROADCAST_CODE_SIZE]
        return instance

    def __init__(
        self,
        source_id: int,
        broadcast_code: bytes,
    ) -> None:
        # Validate before building the PDU so no malformed operation is created.
        if len(broadcast_code) != self._BROADCAST_CODE_SIZE:
            raise core.InvalidArgumentError("broadcast_code must be 16 bytes")

        super().__init__(
            ControlPointOperation.OpCode.SET_BROADCAST_CODE,
            bytes([source_id]) + broadcast_code,
        )
        self.source_id = source_id
        self.broadcast_code = broadcast_code


class RemoveSourceOperation(ControlPointOperation):
    """Remove Source operation: Source_ID (1)."""

    @classmethod
    def from_parameters(cls, parameters: bytes) -> RemoveSourceOperation:
        """Build an instance from the parameter bytes (opcode excluded).

        Raises:
            core.InvalidArgumentError: if `parameters` is empty.
        """
        if not parameters:
            raise core.InvalidArgumentError("Remove Source parameters too short")

        instance = cls.__new__(cls)
        instance.op_code = ControlPointOperation.OpCode.REMOVE_SOURCE
        instance.parameters = parameters
        instance.source_id = parameters[0]
        return instance

    def __init__(self, source_id: int) -> None:
        super().__init__(ControlPointOperation.OpCode.REMOVE_SOURCE, bytes([source_id]))
        self.source_id = source_id


@dataclasses.dataclass
class BroadcastReceiveState:
    """Decoded value of a Broadcast Receive State characteristic.

    Wire layout: Source_ID (1), Source_Address_Type (1), Source_Address (6),
    Source_Adv_SID (1), Broadcast_ID (3, little-endian), PA_Sync_State (1),
    BIG_Encryption (1), Bad_Code (16, present only when BIG_Encryption is
    BAD_CODE), then the subgroup list.
    """

    class PeriodicAdvertisingSyncState(utils.OpenIntEnum):
        NOT_SYNCHRONIZED_TO_PA = 0x00
        SYNCINFO_REQUEST = 0x01
        SYNCHRONIZED_TO_PA = 0x02
        FAILED_TO_SYNCHRONIZE_TO_PA = 0x03
        NO_PAST = 0x04

    class BigEncryption(utils.OpenIntEnum):
        NOT_ENCRYPTED = 0x00
        BROADCAST_CODE_REQUIRED = 0x01
        DECRYPTING = 0x02
        BAD_CODE = 0x03

    source_id: int
    source_address: hci.Address
    source_adv_sid: int
    broadcast_id: int
    pa_sync_state: PeriodicAdvertisingSyncState
    big_encryption: BigEncryption
    # from_bytes yields 16 bytes when big_encryption is BAD_CODE, else b"".
    # __bytes__ emits this field verbatim, so keep it consistent.
    bad_code: bytes
    subgroups: List[SubgroupInfo]

    @classmethod
    def from_bytes(cls, data: bytes) -> Optional[BroadcastReceiveState]:
        """Decode a characteristic value; an empty value decodes to None.

        Raises:
            core.InvalidArgumentError: if `data` is non-empty but shorter than
                the fixed fields, or than the Bad_Code it announces.
        """
        if not data:
            return None

        header_size = 14  # Everything before the optional Bad_Code.
        bad_code_size = 16
        if len(data) < header_size:
            raise core.InvalidArgumentError("Broadcast Receive State too short")

        source_id = data[0]
        # Address type is at offset 1, the 6-byte address starts at offset 2.
        _, source_address = hci.Address.parse_address_preceded_by_type(data, 2)
        source_adv_sid = data[8]
        broadcast_id = int.from_bytes(data[9:12], "little")
        pa_sync_state = cls.PeriodicAdvertisingSyncState(data[12])
        big_encryption = cls.BigEncryption(data[13])

        offset = header_size
        bad_code = b""
        if big_encryption == cls.BigEncryption.BAD_CODE:
            if len(data) < header_size + bad_code_size:
                raise core.InvalidArgumentError("Bad_Code truncated")
            bad_code = data[offset : offset + bad_code_size]
            offset += bad_code_size
        subgroups = decode_subgroups(data[offset:])

        return cls(
            source_id,
            source_address,
            source_adv_sid,
            broadcast_id,
            pa_sync_state,
            big_encryption,
            bad_code,
            subgroups,
        )

    def __bytes__(self) -> bytes:
        """Encode this state using the wire layout described on the class."""
        return (
            struct.pack(
                "<BB6sB3sBB",
                self.source_id,
                self.source_address.address_type,
                bytes(self.source_address),
                self.source_adv_sid,
                self.broadcast_id.to_bytes(3, "little"),
                self.pa_sync_state,
                self.big_encryption,
            )
            + self.bad_code
            + encode_subgroups(self.subgroups)
        )
# -----------------------------------------------------------------------------
class BroadcastAudioScanService(gatt.TemplateService):
    """Server side of the Broadcast Audio Scan Service (work in progress).

    Exposes a Broadcast Audio Scan Control Point characteristic and a single
    Broadcast Receive State characteristic. Control point writes are not yet
    handled.
    """

    UUID = gatt.GATT_BROADCAST_AUDIO_SCAN_SERVICE

    def __init__(self):
        self.broadcast_audio_scan_control_point_characteristic = gatt.Characteristic(
            gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC,
            gatt.Characteristic.Properties.WRITE
            | gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
            gatt.Characteristic.Permissions.WRITEABLE,
            gatt.CharacteristicValue(
                write=self.on_broadcast_audio_scan_control_point_write
            ),
        )

        self.broadcast_receive_state_characteristic = gatt.Characteristic(
            gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC,
            gatt.Characteristic.Properties.READ | gatt.Characteristic.Properties.NOTIFY,
            gatt.Characteristic.Permissions.READABLE
            | gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
            # NOTE(review): placeholder value (marked TEST in the original);
            # not an encoded BroadcastReceiveState.
            b"12",  # TEST
        )

        # Register the characteristics created above. The previous code
        # referenced a nonexistent `battery_level_characteristic`, so
        # constructing the service always raised AttributeError.
        super().__init__(
            [
                self.broadcast_audio_scan_control_point_characteristic,
                self.broadcast_receive_state_characteristic,
            ]
        )

    def on_broadcast_audio_scan_control_point_write(
        self, connection: device.Connection, value: bytes
    ) -> None:
        """Handle a control point write.

        Not implemented yet: the written value is currently ignored.
        """
        pass


# -----------------------------------------------------------------------------
class BroadcastAudioScanServiceProxy(gatt_client.ProfileServiceProxy):
    """Client-side proxy for a remote Broadcast Audio Scan Service."""

    SERVICE_CLASS = BroadcastAudioScanService

    broadcast_audio_scan_control_point: gatt_client.CharacteristicProxy
    # One adapter per Broadcast Receive State characteristic; values are
    # decoded with BroadcastReceiveState.from_bytes.
    broadcast_receive_states: List[gatt.DelegatedCharacteristicAdapter]

    def __init__(self, service_proxy: gatt_client.ServiceProxy):
        """Bind the required characteristics of `service_proxy`.

        Raises:
            gatt.InvalidServiceError: if the control point or any Broadcast
                Receive State characteristic is missing.
        """
        self.service_proxy = service_proxy

        if not (
            characteristics := service_proxy.get_characteristics_by_uuid(
                gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC
            )
        ):
            raise gatt.InvalidServiceError(
                "Broadcast Audio Scan Control Point characteristic not found"
            )
        self.broadcast_audio_scan_control_point = characteristics[0]

        if not (
            characteristics := service_proxy.get_characteristics_by_uuid(
                gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC
            )
        ):
            raise gatt.InvalidServiceError(
                "Broadcast Receive State characteristic not found"
            )
        self.broadcast_receive_states = [
            gatt.DelegatedCharacteristicAdapter(
                characteristic, decode=BroadcastReceiveState.from_bytes
            )
            for characteristic in characteristics
        ]

    async def send_control_point_operation(
        self, operation: ControlPointOperation
    ) -> None:
        """Write `operation` to the control point, using a write with response."""
        await self.broadcast_audio_scan_control_point.write_value(
            bytes(operation), with_response=True
        )

    async def remote_scan_started(self) -> None:
        """Send a Remote Scan Started operation."""
        await self.send_control_point_operation(RemoteScanStartedOperation())

    async def remote_scan_stopped(self) -> None:
        """Send a Remote Scan Stopped operation."""
        await self.send_control_point_operation(RemoteScanStoppedOperation())

    async def add_source(
        self,
        advertiser_address: hci.Address,
        advertising_sid: int,
        broadcast_id: int,
        pa_sync: PeriodicAdvertisingSyncParams,
        pa_interval: int,
        subgroups: Sequence[SubgroupInfo],
    ) -> None:
        """Send an Add Source operation."""
        await self.send_control_point_operation(
            AddSourceOperation(
                advertiser_address,
                advertising_sid,
                broadcast_id,
                pa_sync,
                pa_interval,
                subgroups,
            )
        )

    async def modify_source(
        self,
        source_id: int,
        pa_sync: PeriodicAdvertisingSyncParams,
        pa_interval: int,
        subgroups: Sequence[SubgroupInfo],
    ) -> None:
        """Send a Modify Source operation."""
        await self.send_control_point_operation(
            ModifySourceOperation(
                source_id,
                pa_sync,
                pa_interval,
                subgroups,
            )
        )

    async def set_broadcast_code(self, source_id: int, broadcast_code: bytes) -> None:
        """Send a Set Broadcast Code operation.

        `broadcast_code` must be exactly 16 bytes; SetBroadcastCodeOperation
        raises core.InvalidArgumentError otherwise.
        """
        await self.send_control_point_operation(
            SetBroadcastCodeOperation(source_id, broadcast_code)
        )

    async def remove_source(self, source_id: int) -> None:
        """Send a Remove Source operation."""
        await self.send_control_point_operation(RemoveSourceOperation(source_id))