# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import os
import logging
import click
from prompt_toolkit.shortcuts import PromptSession

from bumble.colors import color
from bumble.device import Device, Peer
from bumble.transport import open_transport_or_link
from bumble.pairing import OobData, PairingDelegate, PairingConfig
from bumble.smp import OobContext, OobLegacyContext
from bumble.smp import error_name as smp_error_name
from bumble.keys import JsonKeyStore
from bumble.core import (
    AdvertisingData,
    ProtocolError,
    BT_LE_TRANSPORT,
    BT_BR_EDR_TRANSPORT,
)
from bumble.gatt import (
    GATT_DEVICE_NAME_CHARACTERISTIC,
    GATT_GENERIC_ACCESS_SERVICE,
    Service,
    Characteristic,
    CharacteristicValue,
)
from bumble.att import (
    ATT_Error,
    ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
    ATT_INSUFFICIENT_ENCRYPTION_ERROR,
)
from bumble.utils import AsyncRunner

# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Seconds to wait after a successful pairing before disconnecting.
POST_PAIRING_DELAY = 1


# -----------------------------------------------------------------------------
class Waiter:
    """Future-based latch that keeps the app running until pairing is over.

    A single shared instance is stored in ``Waiter.instance`` so that event
    handlers without access to local state can signal termination.
    """

    instance = None

    def __init__(self, linger=False):
        # Must be constructed from inside a running event loop.
        self.done = asyncio.get_running_loop().create_future()
        # When lingering, terminate() is a no-op and the app keeps running.
        self.linger = linger

    def terminate(self):
        """Signal completion, unless lingering or already terminated."""
        # Several events (e.g. multiple connections, or more than one
        # terminating event per connection) may each call terminate(); setting
        # the result of an already-completed future raises InvalidStateError.
        if not self.linger and not self.done.done():
            self.done.set_result(None)

    async def wait_until_terminated(self):
        """Wait until terminate() completes the future (never when lingering)."""
        return await self.done


# -----------------------------------------------------------------------------
class Delegate(PairingDelegate):
    """Pairing delegate that interacts with the user on the console.

    Args:
        mode: 'le' or 'classic'; selects how the peer's name is looked up.
        connection: the connection being paired.
        capability_string: one of 'keyboard', 'display', 'display+keyboard',
            'display+yes/no' or 'none' (case-insensitive), mapped to the
            corresponding PairingDelegate I/O capability.
        do_prompt: when true, ask the user before accepting a pairing request.
    """

    def __init__(self, mode, connection, capability_string, do_prompt):
        super().__init__(
            io_capability={
                'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
                'display': PairingDelegate.DISPLAY_OUTPUT_ONLY,
                'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT,
                'display+yes/no': PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT,
                'none': PairingDelegate.NO_OUTPUT_NO_INPUT,
            }[capability_string.lower()]
        )

        self.mode = mode
        self.peer = Peer(connection)
        # Cached display name, computed lazily by update_peer_name().
        self.peer_name = None
        self.do_prompt = do_prompt

    def print(self, message):
        """Print a delegate message in yellow."""
        print(color(message, 'yellow'))

    async def prompt(self, message, lowercase=True):
        """Ask the user for one line of input and return it stripped.

        Args:
            message: the prompt text.
            lowercase: lowercase the answer (handy for yes/no questions). Pass
                False for case-sensitive input such as PIN strings.
        """
        # Wait a bit to allow some of the log lines to print before we prompt
        await asyncio.sleep(1)

        session = PromptSession(message)
        response = (await session.prompt_async()).strip()
        return response.lower() if lowercase else response

    async def update_peer_name(self):
        """Compute and cache a display name for the peer (best effort)."""
        if self.peer_name is not None:
            # We already asked the peer
            return

        # Try to get the peer's name
        if self.peer:
            try:
                peer_name = await get_peer_name(self.peer, self.mode)
            except (ProtocolError, asyncio.TimeoutError) as error:
                # The name is only used for display: a failed lookup must not
                # abort the pairing interaction.
                self.print(f'### Unable to get the peer name: {error}')
                peer_name = None
            self.peer_name = f'{peer_name or ""} [{self.peer.connection.peer_address}]'
        else:
            self.peer_name = '[?]'

    async def accept(self):
        """Return whether to accept a pairing request.

        Accepts silently unless do_prompt is set, in which case the user is
        asked until they answer 'yes' or 'no'.
        """
        if not self.do_prompt:
            # Accept silently
            return True

        await self.update_peer_name()

        # Prompt for acceptance
        self.print('###-----------------------------------')
        self.print(f'### Pairing request from {self.peer_name}')
        self.print('###-----------------------------------')
        while True:
            response = await self.prompt('>>> Accept? ')

            if response == 'yes':
                return True

            if response == 'no':
                return False

    async def compare_numbers(self, number, digits):
        """Ask the user whether the peer displays `number`, zero-padded to `digits`."""
        await self.update_peer_name()

        # Prompt for a numeric comparison
        self.print('###-----------------------------------')
        self.print(f'### Pairing with {self.peer_name}')
        self.print('###-----------------------------------')
        while True:
            response = await self.prompt(
                f'>>> Does the other device display {number:0{digits}}? '
            )

            if response == 'yes':
                return True

            if response == 'no':
                return False

    async def get_number(self):
        """Prompt until the user enters an integer PIN, and return it."""
        await self.update_peer_name()

        # Prompt for a PIN
        while True:
            try:
                self.print('###-----------------------------------')
                self.print(f'### Pairing with {self.peer_name}')
                self.print('###-----------------------------------')
                return int(await self.prompt('>>> Enter PIN: '))
            except ValueError:
                # Not a number: ask again.
                pass

    async def display_number(self, number, digits):
        """Display `number`, zero-padded to `digits`, for the user to enter on the peer."""
        await self.update_peer_name()

        # Display a PIN code
        self.print('###-----------------------------------')
        self.print(f'### Pairing with {self.peer_name}')
        self.print(f'### PIN: {number:0{digits}}')
        self.print('###-----------------------------------')

    async def get_string(self, max_length: int):
        """Prompt for a PIN string (legacy pairing in classic).

        Args:
            max_length: maximum accepted PIN length, in characters.

        Returns:
            The PIN as typed (case preserved), or None to abort the pairing
            after more than 3 empty or over-long entries.
        """
        await self.update_peer_name()

        # Prompt a PIN (for legacy pairing in classic)
        self.print('###-----------------------------------')
        self.print(f'### Pairing with {self.peer_name}')
        self.print('###-----------------------------------')
        count = 0
        while True:
            # PIN strings are case-sensitive: keep the user's casing.
            response = await self.prompt('>>> Enter PIN (1-6 chars):', lowercase=False)
            if 0 < len(response) <= max_length:
                return response

            count += 1
            if count > 3:
                self.print('too many tries, stopping the pairing')
                return None

            if len(response) == 0:
                self.print('no PIN was entered, try again')
            else:
                self.print(f'PIN is too long (max {max_length} chars), try again')
# -----------------------------------------------------------------------------
async def get_peer_name(peer, mode):
    """Return the peer's name, or None if it cannot be found.

    In 'classic' mode the name comes from peer.request_name(). Otherwise it is
    read from the Device Name characteristic of the peer's GATT Generic Access
    service.
    """
    if mode == 'classic':
        return await peer.request_name()

    # Try to get the peer name from GATT
    services = await peer.discover_service(GATT_GENERIC_ACCESS_SERVICE)
    if not services:
        return None

    values = await peer.read_characteristics_by_uuid(
        GATT_DEVICE_NAME_CHARACTERISTIC, services[0]
    )
    if values:
        # The bytes come from the remote device: don't raise on a malformed
        # encoding, substitute the undecodable bytes instead.
        return values[0].decode('utf-8', errors='replace')

    return None


# -----------------------------------------------------------------------------
# One-shot flags: index 0 for reads, index 1 for writes. Each is set the first
# time the corresponding handler raises an insufficient-authentication error,
# so only the first encrypted access of each kind fails that way.
AUTHENTICATION_ERROR_RETURNED = [False, False]


def read_with_error(connection):
    """GATT read handler used to trigger pairing from the peer side.

    Raises an insufficient-encryption error while the link is not encrypted,
    and an insufficient-authentication error on the first encrypted read.
    Subsequent encrypted reads return bytes([1]).
    """
    if not connection.is_encrypted:
        raise ATT_Error(ATT_INSUFFICIENT_ENCRYPTION_ERROR)

    if AUTHENTICATION_ERROR_RETURNED[0]:
        return bytes([1])

    AUTHENTICATION_ERROR_RETURNED[0] = True
    raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR)


def write_with_error(connection, _value):
    """GATT write handler used to trigger pairing from the peer side.

    Raises an insufficient-encryption error while the link is not encrypted,
    and an insufficient-authentication error on the first encrypted write.
    Subsequent encrypted writes are accepted (the value is ignored).
    """
    if not connection.is_encrypted:
        raise ATT_Error(ATT_INSUFFICIENT_ENCRYPTION_ERROR)

    if not AUTHENTICATION_ERROR_RETURNED[1]:
        AUTHENTICATION_ERROR_RETURNED[1] = True
        raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR)


# -----------------------------------------------------------------------------
def on_connection(connection, request):
    """Log a new connection and subscribe to its pairing/encryption events.

    Args:
        connection: the new connection.
        request: when true, call connection.request_pairing() so that the
            peer initiates pairing.
    """
    print(color(f'<<< Connection: {connection}', 'green'))

    # Listen for pairing events
    connection.on('pairing_start', on_pairing_start)
    connection.on('pairing', lambda keys: on_pairing(connection, keys))
    connection.on(
        'pairing_failure', lambda reason: on_pairing_failure(connection, reason)
    )

    # Listen for encryption changes
    connection.on(
        'connection_encryption_change',
        lambda: on_connection_encryption_change(connection),
    )

    # Request pairing if needed
    if request:
        print(color('>>> Requesting pairing', 'green'))
        connection.request_pairing()


# -----------------------------------------------------------------------------
def on_connection_encryption_change(connection):
    """Print whether the connection is now encrypted."""
    print(color('@@@-----------------------------------', 'blue'))
    print(
        color(
            # "not " carries its own trailing space so the unencrypted case
            # reads "not encrypted" rather than "notencrypted".
            f'@@@ Connection is {"" if connection.is_encrypted else "not "}encrypted',
            'blue',
        )
    )
    print(color('@@@-----------------------------------', 'blue'))


# -----------------------------------------------------------------------------
def on_pairing_start():
    """Print a banner when pairing starts."""
    print(color('***-----------------------------------', 'magenta'))
    print(color('*** Pairing starting', 'magenta'))
    print(color('***-----------------------------------', 'magenta'))


# -----------------------------------------------------------------------------
@AsyncRunner.run_in_task()
async def on_pairing(connection, keys):
    """Print the pairing keys, then disconnect and signal termination.

    The disconnect happens after POST_PAIRING_DELAY seconds.
    """
    print(color('***-----------------------------------', 'cyan'))
    print(color(f'*** Paired! (peer identity={connection.peer_address})', 'cyan'))
    keys.print(prefix=color('*** ', 'cyan'))
    print(color('***-----------------------------------', 'cyan'))
    await asyncio.sleep(POST_PAIRING_DELAY)
    await connection.disconnect()
    Waiter.instance.terminate()


# -----------------------------------------------------------------------------
@AsyncRunner.run_in_task()
async def on_pairing_failure(connection, reason):
    """Print the SMP failure reason, then disconnect and signal termination."""
    print(color('***-----------------------------------', 'red'))
    print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red'))
    print(color('***-----------------------------------', 'red'))
    await connection.disconnect()
    Waiter.instance.terminate()


# -----------------------------------------------------------------------------
async def pair(
    mode,
    sc,
    mitm,
    bond,
    ctkd,
    identity_address,
    linger,
    io,
    oob,
    prompt,
    request,
    print_keys,
    keystore_file,
    device_config,
    hci_transport,
    address_or_name,
):
    """Open the HCI transport, configure a device, and pair with a peer.

    If address_or_name is given, connect to that peer, and pair or
    authenticate unless `request` is set. Otherwise advertise (LE) or become
    discoverable/connectable (classic) and wait for an incoming connection.
    Returns when Waiter.instance is terminated.

    Args mirror the command-line options of `main`. `oob` is either None
    (no OOB), '-' (OOB without peer data), or the hex string printed by the
    peer's instance of this tool.
    """
    Waiter.instance = Waiter(linger=linger)

    print('<<< connecting to HCI...')
    async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
        print('<<< connected')

        # Create a device to manage the host
        device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)

        # Expose a GATT characteristic that can be used to trigger pairing by
        # responding with an authentication error when read
        if mode == 'le':
            device.le_enabled = True
            device.add_service(
                Service(
                    '50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
                    [
                        Characteristic(
                            '552957FB-CF1F-4A31-9535-E78847E1A714',
                            Characteristic.Properties.READ
                            | Characteristic.Properties.WRITE,
                            Characteristic.READABLE | Characteristic.WRITEABLE,
                            CharacteristicValue(
                                read=read_with_error, write=write_with_error
                            ),
                        )
                    ],
                )
            )

        # Select LE or Classic
        if mode == 'classic':
            device.classic_enabled = True
            device.classic_smp_enabled = ctkd

        # Get things going
        await device.power_on()

        # Set a custom keystore if specified on the command line
        if keystore_file:
            device.keystore = JsonKeyStore.from_device(device, filename=keystore_file)

        # Print the existing keys before pairing
        if print_keys and device.keystore:
            print(color('@@@-----------------------------------', 'blue'))
            print(color('@@@ Pairing Keys:', 'blue'))
            await device.keystore.print(prefix=color('@@@ ', 'blue'))
            print(color('@@@-----------------------------------', 'blue'))

        # Create an OOB context if needed
        if oob:
            our_oob_context = OobContext()
            peer_oob_data = (
                None
                if oob == '-'
                else OobData.from_ad(AdvertisingData.from_bytes(bytes.fromhex(oob)))
            )
            # OobConfig.peer_data takes the peer's Secure Connections shared
            # data, not the whole OobData container parsed from the hex string.
            shared_data = None if peer_oob_data is None else peer_oob_data.shared_data
            # Legacy OOB pairing requires both sides to use the same TK: reuse
            # the peer's TK when it supplied one, otherwise generate a new one.
            if peer_oob_data is not None and peer_oob_data.legacy_context is not None:
                legacy_context = peer_oob_data.legacy_context
            else:
                legacy_context = OobLegacyContext()
            oob_contexts = PairingConfig.OobConfig(
                our_context=our_oob_context,
                peer_data=shared_data,
                legacy_context=legacy_context,
            )
            # The data printed below is meant to be given to the peer, so it
            # must carry OUR Secure Connections values, not the ones received.
            oob_data = OobData(
                address=device.random_address,
                shared_data=our_oob_context.share(),
                legacy_context=legacy_context,
            )
            print(color('@@@-----------------------------------', 'yellow'))
            print(color('@@@ OOB Data:', 'yellow'))
            print(color(f'@@@ {our_oob_context.share()}', 'yellow'))
            print(color(f'@@@ TK={legacy_context.tk.hex()}', 'yellow'))
            print(color(f'@@@ HEX: ({bytes(oob_data.to_ad()).hex()})', 'yellow'))
            print(color('@@@-----------------------------------', 'yellow'))
        else:
            oob_contexts = None

        # Set up a pairing config factory
        if identity_address == 'public':
            identity_address_type = PairingConfig.AddressType.PUBLIC
        elif identity_address == 'random':
            identity_address_type = PairingConfig.AddressType.RANDOM
        else:
            identity_address_type = None
        device.pairing_config_factory = lambda connection: PairingConfig(
            sc=sc,
            mitm=mitm,
            bonding=bond,
            oob=oob_contexts,
            identity_address_type=identity_address_type,
            delegate=Delegate(mode, connection, io, prompt),
        )

        # Connect to a peer or wait for a connection
        device.on('connection', lambda connection: on_connection(connection, request))
        if address_or_name is not None:
            print(color(f'=== Connecting to {address_or_name}...', 'green'))
            connection = await device.connect(
                address_or_name,
                transport=BT_LE_TRANSPORT if mode == 'le' else BT_BR_EDR_TRANSPORT,
            )

            if not request:
                try:
                    if mode == 'le':
                        await connection.pair()
                    else:
                        await connection.authenticate()
                except ProtocolError as error:
                    print(color(f'Pairing failed: {error}', 'red'))

        else:
            if mode == 'le':
                # Advertise so that peers can find us and connect
                await device.start_advertising(auto_restart=True)
            else:
                # Become discoverable and connectable
                await device.set_discoverable(True)
                await device.set_connectable(True)

        # Run until the user asks to exit
        await Waiter.instance.wait_until_terminated()


# -----------------------------------------------------------------------------
class LogHandler(logging.Handler):
    """Logging handler that prints 'LEVEL:logger:message' lines to stdout."""

    def __init__(self):
        super().__init__()
        self.setFormatter(logging.Formatter('%(levelname)s:%(name)s:%(message)s'))

    def emit(self, record):
        message = self.format(record)
        print(message)


# -----------------------------------------------------------------------------
@click.command()
@click.option(
    '--mode', type=click.Choice(['le', 'classic']), default='le', show_default=True
)
@click.option(
    '--sc',
    type=bool,
    default=True,
    help='Use the Secure Connections protocol',
    show_default=True,
)
@click.option(
    '--mitm', type=bool, default=True, help='Request MITM protection', show_default=True
)
@click.option(
    '--bond', type=bool, default=True, help='Enable bonding', show_default=True
)
@click.option(
    '--ctkd',
    type=bool,
    default=True,
    help='Enable CTKD',
    show_default=True,
)
@click.option(
    '--identity-address',
    type=click.Choice(['random', 'public']),
)
@click.option('--linger', default=False, is_flag=True, help='Linger after pairing')
@click.option(
    '--io',
    type=click.Choice(
        ['keyboard', 'display', 'display+keyboard', 'display+yes/no', 'none']
    ),
    default='display+keyboard',
    show_default=True,
)
@click.option(
    '--oob',
    metavar='<oob-data-hex>',
    help=(
        'Use OOB pairing with this data from the peer '
        '(use "-" to enable OOB without peer data)'
    ),
)
@click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request')
@click.option(
    '--request', is_flag=True, help='Request that the connecting peer initiate pairing'
)
@click.option('--print-keys', is_flag=True, help='Print the bond keys before pairing')
@click.option(
    '--keystore-file',
    metavar='<filename>',
    help='File in which to store the pairing keys',
)
@click.argument('device-config')
@click.argument('hci_transport')
@click.argument('address-or-name', required=False)
def main(
    mode,
    sc,
    mitm,
    bond,
    ctkd,
    identity_address,
    linger,
    io,
    oob,
    prompt,
    request,
    print_keys,
    keystore_file,
    device_config,
    hci_transport,
    address_or_name,
):
    """Command-line entry point: set up logging and run `pair`."""
    # Setup logging (level taken from the BUMBLE_LOGLEVEL environment variable)
    log_handler = LogHandler()
    root_logger = logging.getLogger()
    root_logger.addHandler(log_handler)
    root_logger.setLevel(os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())

    # Pair
    asyncio.run(
        pair(
            mode,
            sc,
            mitm,
            bond,
            ctkd,
            identity_address,
            linger,
            io,
            oob,
            prompt,
            request,
            print_keys,
            keystore_file,
            device_config,
            hci_transport,
            address_or_name,
        )
    )


# -----------------------------------------------------------------------------
if __name__ == '__main__':
    main()