# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import logging
import asyncio
from functools import partial

from bumble.core import (
    BT_PERIPHERAL_ROLE,
    BT_BR_EDR_TRANSPORT,
    BT_LE_TRANSPORT,
    InvalidStateError,
)
from bumble.colors import color
from bumble.hci import (
    Address,
    HCI_SUCCESS,
    HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
    HCI_CONNECTION_TIMEOUT_ERROR,
    HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
    HCI_PAGE_TIMEOUT_ERROR,
    HCI_Connection_Complete_Event,
)
from bumble import controller

from typing import Optional, Set

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def parse_parameters(params_str: Optional[str]) -> dict:
    '''
    Parse a comma-separated list of ``key=value`` pairs into a dict.

    Entries that contain no ``=`` are skipped. Only the first ``=`` of an
    entry separates the key from the value, so values may themselves contain
    ``=``. An empty string or ``None`` (e.g. a relay message that carried no
    payload) yields an empty dict.

    Returns:
        A dict mapping each key to its string value; when a key repeats, the
        last occurrence wins.
    '''
    result = {}
    if not params_str:
        return result
    for param_str in params_str.split(','):
        # partition() splits on the first '=' only; sep is '' when absent.
        key, sep, value = param_str.partition('=')
        if sep:
            result[key] = value
    return result
# -----------------------------------------------------------------------------
# Background task bookkeeping
# -----------------------------------------------------------------------------
# The asyncio event loop only keeps weak references to tasks, so a task created
# with asyncio.create_task() and not referenced anywhere else may be garbage
# collected before it finishes. Fire-and-forget tasks are parked here until done.
_background_tasks: Set[asyncio.Task] = set()


def _spawn(coroutine) -> asyncio.Task:
    '''Schedule `coroutine` as a task, holding a strong reference until it completes.'''
    task = asyncio.create_task(coroutine)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


# -----------------------------------------------------------------------------
# TODO: add more support for various LL exchanges
# (see Vol 6, Part B - 2.4 DATA CHANNEL PDU)
# -----------------------------------------------------------------------------
class LocalLink:
    '''
    Link bus for controllers to communicate with each other

    All registered controllers live in the same process. LE peers are looked up
    by random address, Classic (BR/EDR) peers by public address. Most peer
    callbacks are deferred to the running event loop (call_soon / tasks).
    '''

    controllers: Set[controller.Controller]

    def __init__(self):
        self.controllers = set()
        # (central_address, le_create_connection_command) while an LE connection
        # is being established, otherwise None.
        self.pending_connection = None
        # (initiator_controller, responder_controller) while a Classic connection
        # request awaits acceptance, otherwise None.
        self.pending_classic_connection = None

    ############################################################
    # Common utils
    ############################################################

    def add_controller(self, controller):
        logger.debug(f'new controller: {controller}')
        self.controllers.add(controller)

    def remove_controller(self, controller):
        # Raises KeyError if `controller` was never added.
        self.controllers.remove(controller)

    def find_controller(self, address):
        '''Return a controller whose LE random address equals `address`, or None.'''
        for candidate in self.controllers:
            if candidate.random_address == address:
                return candidate
        return None

    def find_classic_controller(
        self, address: Address
    ) -> Optional[controller.Controller]:
        '''Return a controller whose public address equals `address`, or None.'''
        for candidate in self.controllers:
            if candidate.public_address == address:
                return candidate
        return None

    def get_pending_connection(self):
        return self.pending_connection

    def get_pending_classic_connection(self):
        return self.pending_classic_connection

    ############################################################
    # LE handlers
    ############################################################

    def on_address_changed(self, controller):
        # Nothing to do: lookups always read the controllers' current addresses.
        pass

    def send_advertising_data(self, sender_address, data):
        # Send the advertising data to all controllers, except the sender
        for controller in self.controllers:
            if controller.random_address != sender_address:
                controller.on_link_advertising_data(sender_address, data)

    def send_acl_data(self, sender_controller, destination_address, transport, data):
        '''
        Deliver ACL `data` to a controller whose address matches.

        LE traffic is routed by random address, BR/EDR traffic by public address.
        Data for an unknown destination or on an unsupported transport is dropped.
        '''
        if transport == BT_LE_TRANSPORT:
            destination_controller = self.find_controller(destination_address)
            source_address = sender_controller.random_address
        elif transport == BT_BR_EDR_TRANSPORT:
            destination_controller = self.find_classic_controller(destination_address)
            source_address = sender_controller.public_address
        else:
            # Without this branch the locals above would be unbound.
            logger.warning(f'send_acl_data: unsupported transport {transport}')
            return

        if destination_controller is not None:
            destination_controller.on_link_acl_data(source_address, transport, data)

    def on_connection_complete(self):
        '''Complete the pending LE connection scheduled by connect().'''
        # Check that we expect this call
        if not self.pending_connection:
            logger.warning('on_connection_complete with no pending connection')
            return

        central_address, le_create_connection_command = self.pending_connection
        self.pending_connection = None

        # Find the controller that initiated the connection
        if not (central_controller := self.find_controller(central_address)):
            logger.warning('!!! Initiating controller not found')
            return

        # Connect to the first controller with a matching address
        if peripheral_controller := self.find_controller(
            le_create_connection_command.peer_address
        ):
            central_controller.on_link_peripheral_connection_complete(
                le_create_connection_command, HCI_SUCCESS
            )
            peripheral_controller.on_link_central_connected(central_address)
            return

        # No peripheral found
        central_controller.on_link_peripheral_connection_complete(
            le_create_connection_command, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
        )

    def connect(self, central_address, le_create_connection_command):
        logger.debug(
            f'$$$ CONNECTION {central_address} -> '
            f'{le_create_connection_command.peer_address}'
        )
        self.pending_connection = (central_address, le_create_connection_command)
        asyncio.get_running_loop().call_soon(self.on_connection_complete)

    def on_disconnection_complete(
        self, central_address, peripheral_address, disconnect_command
    ):
        # Find the controller that initiated the disconnection
        if not (central_controller := self.find_controller(central_address)):
            logger.warning('!!! Initiating controller not found')
            return

        # Disconnect from the first controller with a matching address
        if peripheral_controller := self.find_controller(peripheral_address):
            peripheral_controller.on_link_central_disconnected(
                central_address, disconnect_command.reason
            )

        central_controller.on_link_peripheral_disconnection_complete(
            disconnect_command, HCI_SUCCESS
        )

    def disconnect(self, central_address, peripheral_address, disconnect_command):
        logger.debug(
            f'$$$ DISCONNECTION {central_address} -> '
            f'{peripheral_address}: reason = {disconnect_command.reason}'
        )
        asyncio.get_running_loop().call_soon(
            self.on_disconnection_complete,
            central_address,
            peripheral_address,
            disconnect_command,
        )

    # pylint: disable=too-many-arguments
    def on_connection_encrypted(
        self, central_address, peripheral_address, rand, ediv, ltk
    ):
        logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}')

        if central_controller := self.find_controller(central_address):
            central_controller.on_link_encrypted(peripheral_address, rand, ediv, ltk)

        if peripheral_controller := self.find_controller(peripheral_address):
            peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)

    def create_cis(
        self,
        central_controller: controller.Controller,
        peripheral_address: Address,
        cig_id: int,
        cis_id: int,
    ) -> None:
        logger.debug(
            f'$$$ CIS Request {central_controller.random_address} -> {peripheral_address}'
        )
        if peripheral_controller := self.find_controller(peripheral_address):
            asyncio.get_running_loop().call_soon(
                peripheral_controller.on_link_cis_request,
                central_controller.random_address,
                cig_id,
                cis_id,
            )

    def accept_cis(
        self,
        peripheral_controller: controller.Controller,
        central_address: Address,
        cig_id: int,
        cis_id: int,
    ) -> None:
        logger.debug(
            f'$$$ CIS Accept {peripheral_controller.random_address} -> {central_address}'
        )
        if central_controller := self.find_controller(central_address):
            loop = asyncio.get_running_loop()
            loop.call_soon(central_controller.on_link_cis_established, cig_id, cis_id)
            loop.call_soon(
                peripheral_controller.on_link_cis_established, cig_id, cis_id
            )

    def disconnect_cis(
        self,
        initiator_controller: controller.Controller,
        peer_address: Address,
        cig_id: int,
        cis_id: int,
    ) -> None:
        logger.debug(
            f'$$$ CIS Disconnect {initiator_controller.random_address} -> {peer_address}'
        )
        if peer_controller := self.find_controller(peer_address):
            loop = asyncio.get_running_loop()
            loop.call_soon(
                initiator_controller.on_link_cis_disconnected, cig_id, cis_id
            )
            loop.call_soon(peer_controller.on_link_cis_disconnected, cig_id, cis_id)

    ############################################################
    # Classic handlers
    ############################################################

    def classic_connect(self, initiator_controller, responder_address):
        logger.debug(
            f'[Classic] {initiator_controller.public_address} connects to {responder_address}'
        )
        responder_controller = self.find_classic_controller(responder_address)
        if responder_controller is None:
            initiator_controller.on_classic_connection_complete(
                responder_address, HCI_PAGE_TIMEOUT_ERROR
            )
            return
        self.pending_classic_connection = (initiator_controller, responder_controller)

        responder_controller.on_classic_connection_request(
            initiator_controller.public_address,
            HCI_Connection_Complete_Event.ACL_LINK_TYPE,
        )

    def classic_accept_connection(
        self, responder_controller, initiator_address, responder_role
    ):
        logger.debug(
            f'[Classic] {responder_controller.public_address} accepts to connect {initiator_address}'
        )
        initiator_controller = self.find_classic_controller(initiator_address)
        if initiator_controller is None:
            # Report the failure against the peer (the initiator), as every other
            # completion path does, and drop the stale pending request.
            self.pending_classic_connection = None
            responder_controller.on_classic_connection_complete(
                initiator_address, HCI_PAGE_TIMEOUT_ERROR
            )
            return

        async def task():
            # The initiator's role is the opposite of the one the responder took.
            if responder_role != BT_PERIPHERAL_ROLE:
                initiator_controller.on_classic_role_change(
                    responder_controller.public_address, int(not responder_role)
                )
            initiator_controller.on_classic_connection_complete(
                responder_controller.public_address, HCI_SUCCESS
            )

        _spawn(task())
        responder_controller.on_classic_role_change(
            initiator_controller.public_address, responder_role
        )
        responder_controller.on_classic_connection_complete(
            initiator_controller.public_address, HCI_SUCCESS
        )
        self.pending_classic_connection = None

    def classic_disconnect(self, initiator_controller, responder_address, reason):
        logger.debug(
            f'[Classic] {initiator_controller.public_address} disconnects {responder_address}'
        )
        responder_controller = self.find_classic_controller(responder_address)

        async def task():
            initiator_controller.on_classic_disconnected(responder_address, reason)

        _spawn(task())
        # The responder may already be gone; the initiator is still notified.
        if responder_controller is not None:
            responder_controller.on_classic_disconnected(
                initiator_controller.public_address, reason
            )

    def classic_switch_role(
        self, initiator_controller, responder_address, initiator_new_role
    ):
        responder_controller = self.find_classic_controller(responder_address)
        if responder_controller is None:
            return

        async def task():
            initiator_controller.on_classic_role_change(
                responder_address, initiator_new_role
            )

        _spawn(task())
        responder_controller.on_classic_role_change(
            initiator_controller.public_address, int(not initiator_new_role)
        )

    def classic_sco_connect(
        self,
        initiator_controller: controller.Controller,
        responder_address: Address,
        link_type: int,
    ):
        logger.debug(
            f'[Classic] {initiator_controller.public_address} connects SCO to {responder_address}'
        )
        responder_controller = self.find_classic_controller(responder_address)
        # Initiator controller should handle it.
        assert responder_controller

        responder_controller.on_classic_connection_request(
            initiator_controller.public_address,
            link_type,
        )

    def classic_accept_sco_connection(
        self,
        responder_controller: controller.Controller,
        initiator_address: Address,
        link_type: int,
    ):
        logger.debug(
            f'[Classic] {responder_controller.public_address} accepts to connect SCO {initiator_address}'
        )
        initiator_controller = self.find_classic_controller(initiator_address)
        if initiator_controller is None:
            # Report the failure against the peer (the initiator), not ourselves.
            responder_controller.on_classic_sco_connection_complete(
                initiator_address,
                HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
                link_type,
            )
            return

        async def task():
            initiator_controller.on_classic_sco_connection_complete(
                responder_controller.public_address, HCI_SUCCESS, link_type
            )

        _spawn(task())
        responder_controller.on_classic_sco_connection_complete(
            initiator_controller.public_address, HCI_SUCCESS, link_type
        )


# -----------------------------------------------------------------------------
class RemoteLink:
    '''
    A Link implementation that communicates with other virtual controllers via a
    WebSocket relay

    Incoming relay frames have the form ``keyword[:payload]`` and are dispatched
    to ``on_<keyword>_received``. Peer messages arrive as
    ``message:<sender>/<keyword>[:payload]`` and are dispatched to
    ``on_<keyword>_message_received(sender, payload)``. Outgoing peer messages
    are sent as ``@<target> <message>``.
    '''

    def __init__(self, uri):
        self.controller = None
        self.uri = uri
        # Coroutines queued by execute() and awaited one at a time.
        self.execution_queue = asyncio.Queue()
        # Resolved with the websocket once connected, or with the connect error.
        self.websocket = asyncio.get_running_loop().create_future()
        self.rpc_result = None
        self.pending_connection = None
        # Read by get_pending_classic_connection(); must exist from the start.
        self.pending_classic_connection = None
        self.central_connections = set()  # Addresses (str) we have connected to
        self.peripheral_connections = set()  # Addresses (str) that connected to us

        # Connect and run asynchronously
        _spawn(self.run_connection())
        _spawn(self.run_executor_loop())

    def add_controller(self, controller):
        if self.controller:
            raise InvalidStateError('controller already set')
        self.controller = controller

    def remove_controller(self, controller):
        if self.controller != controller:
            raise InvalidStateError('controller mismatch')
        self.controller = None

    def get_pending_connection(self):
        return self.pending_connection

    def get_pending_classic_connection(self):
        return self.pending_classic_connection

    async def wait_until_connected(self):
        await self.websocket

    def execute(self, async_function):
        '''Queue `async_function()` to run on the serial executor loop.'''
        self.execution_queue.put_nowait(async_function())

    async def run_executor_loop(self):
        logger.debug('executor loop starting')
        while True:
            item = await self.execution_queue.get()
            try:
                await item
            except Exception as error:
                logger.warning(
                    f'{color("!!! Exception in async handler:", "red")} {error}'
                )

    async def run_connection(self):
        # Connect to the relay
        logger.debug(f'connecting to {self.uri}')
        try:
            import websockets  # lazy import

            # pylint: disable-next=no-member
            websocket = await websockets.connect(self.uri)
        except Exception as error:
            # Fail anyone awaiting the connection instead of blocking them forever.
            self.websocket.set_exception(error)
            raise
        self.websocket.set_result(websocket)
        logger.debug(f'connected to {self.uri}')

        while True:
            message = await websocket.recv()
            logger.debug(f'received message: {message}')
            keyword, *payload = message.split(':', 1)

            handler_name = f'on_{keyword}_received'
            handler = getattr(self, handler_name, None)
            if handler:
                await handler(payload[0] if payload else None)

    def close(self):
        if not self.websocket.done() or self.websocket.cancelled():
            return
        if self.websocket.exception() is not None:
            # The connection never succeeded: there is nothing to close.
            return
        logger.debug('closing websocket')
        websocket = self.websocket.result()
        _spawn(websocket.close())

    async def on_result_received(self, result):
        # Ignore results nobody is waiting for (or that arrive twice).
        if self.rpc_result is not None and not self.rpc_result.done():
            self.rpc_result.set_result(result)

    async def on_left_received(self, address):
        if address in self.central_connections:
            self.controller.on_link_peripheral_disconnected(Address(address))
            self.central_connections.remove(address)

        if address in self.peripheral_connections:
            self.controller.on_link_central_disconnected(
                Address(address), HCI_CONNECTION_TIMEOUT_ERROR
            )
            self.peripheral_connections.remove(address)

    async def on_unreachable_received(self, target):
        await self.on_left_received(target)

    async def on_message_received(self, message):
        sender, *payload = message.split('/', 1)
        if payload:
            keyword, *payload = payload[0].split(':', 1)
            handler_name = f'on_{keyword}_message_received'
            handler = getattr(self, handler_name, None)
            if handler:
                await handler(sender, payload[0] if payload else None)

    async def on_advertisement_message_received(self, sender, advertisement):
        try:
            self.controller.on_link_advertising_data(
                Address(sender), bytes.fromhex(advertisement)
            )
        except Exception:
            logger.exception('exception')

    async def on_acl_message_received(self, sender, acl_data):
        try:
            # The relay only carries LE traffic (see send_acl_data), and the
            # controller callback takes (address, transport, data).
            self.controller.on_link_acl_data(
                Address(sender), BT_LE_TRANSPORT, bytes.fromhex(acl_data)
            )
        except Exception:
            logger.exception('exception')

    async def on_connect_message_received(self, sender, _):
        # Remember the connection
        self.peripheral_connections.add(sender)

        # Notify the controller
        logger.debug(f'connection from central {sender}')
        self.controller.on_link_central_connected(Address(sender))

        # Accept the connection by responding to it
        await self.send_targeted_message(sender, 'connected')

    async def on_connected_message_received(self, sender, _):
        if not self.pending_connection:
            logger.warning('received a connection ack, but no connection is pending')
            return

        # The request is no longer pending; clear it before notifying so that a
        # new connect() is accepted.
        le_create_connection_command = self.pending_connection
        self.pending_connection = None

        # Remember the connection
        self.central_connections.add(sender)

        # Notify the controller
        logger.debug(
            f'connected to peripheral {le_create_connection_command.peer_address}'
        )
        self.controller.on_link_peripheral_connection_complete(
            le_create_connection_command, HCI_SUCCESS
        )

    async def on_disconnect_message_received(self, sender, message):
        # Notify the controller; the payload may be absent (None).
        params = parse_parameters(message or '')
        reason = int(params.get('reason', str(HCI_CONNECTION_TIMEOUT_ERROR)))
        self.controller.on_link_central_disconnected(Address(sender), reason)

        # Forget the connection
        self.peripheral_connections.discard(sender)

    async def on_encrypted_message_received(self, sender, _):
        # TODO parse params to get real args
        self.controller.on_link_encrypted(Address(sender), bytes(8), 0, bytes(16))

    async def send_rpc_command(self, command):
        '''Send a relay command and wait for its `result` reply (one at a time).'''
        # Ensure we have a connection
        websocket = await self.websocket

        # Create a future value to hold the eventual result
        assert self.rpc_result is None
        self.rpc_result = asyncio.get_running_loop().create_future()
        try:
            # Send the command, then wait for the result
            await websocket.send(command)
            rpc_result = await self.rpc_result
        finally:
            # Always release the slot so a failed or cancelled call does not
            # trip the assertion above on the next command.
            self.rpc_result = None
        logger.debug(f'rpc_result: {rpc_result}')

        # TODO: parse the result

    async def send_targeted_message(self, target, message):
        # Ensure we have a connection
        websocket = await self.websocket

        # Send the message
        await websocket.send(f'@{target} {message}')

    async def notify_address_changed(self):
        await self.send_rpc_command(f'/set-address {self.controller.random_address}')

    def on_address_changed(self, controller):
        logger.info(f'address changed for {controller}: {controller.random_address}')

        # Notify the relay of the change
        self.execute(self.notify_address_changed)

    async def send_advertising_data_to_relay(self, data):
        await self.send_targeted_message('*', f'advertisement:{data.hex()}')

    def send_advertising_data(self, _, data):
        self.execute(partial(self.send_advertising_data_to_relay, data))

    async def send_acl_data_to_relay(self, peer_address, data):
        await self.send_targeted_message(peer_address, f'acl:{data.hex()}')

    def send_acl_data(self, _, peer_address, _transport, data):
        # TODO: handle different transport
        self.execute(partial(self.send_acl_data_to_relay, peer_address, data))

    async def send_connection_request_to_relay(self, peer_address):
        await self.send_targeted_message(peer_address, 'connect')

    def connect(self, _, le_create_connection_command):
        if self.pending_connection:
            logger.warning('connection already pending')
            return
        self.pending_connection = le_create_connection_command
        self.execute(
            partial(
                self.send_connection_request_to_relay,
                str(le_create_connection_command.peer_address),
            )
        )

    def on_disconnection_complete(self, disconnect_command):
        self.controller.on_link_peripheral_disconnection_complete(
            disconnect_command, HCI_SUCCESS
        )

    def disconnect(self, central_address, peripheral_address, disconnect_command):
        logger.debug(
            f'disconnect {central_address} -> '
            f'{peripheral_address}: reason = {disconnect_command.reason}'
        )
        self.execute(
            partial(
                self.send_targeted_message,
                peripheral_address,
                f'disconnect:reason={disconnect_command.reason}',
            )
        )
        asyncio.get_running_loop().call_soon(
            self.on_disconnection_complete, disconnect_command
        )

    def on_connection_encrypted(self, _, peripheral_address, rand, ediv, ltk):
        asyncio.get_running_loop().call_soon(
            self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk
        )
        self.execute(
            partial(
                self.send_targeted_message,
                peripheral_address,
                f'encrypted:ltk={ltk.hex()}',
            )
        )