# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations
import asyncio
import contextlib
import grpc
import logging

from . import utils
from .config import Config
from bumble import hci
from bumble.core import (
    BT_BR_EDR_TRANSPORT,
    BT_LE_TRANSPORT,
    BT_PERIPHERAL_ROLE,
    ProtocolError,
)
from bumble.device import Connection as BumbleConnection, Device
from bumble.hci import HCI_Error
from bumble.utils import EventWatcher
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
from google.protobuf import any_pb2  # pytype: disable=pyi-error
from google.protobuf import empty_pb2  # pytype: disable=pyi-error
from google.protobuf import wrappers_pb2  # pytype: disable=pyi-error
from pandora.host_pb2 import Connection
from pandora.security_grpc_aio import SecurityServicer, SecurityStorageServicer
from pandora.security_pb2 import (
    LE_LEVEL1,
    LE_LEVEL2,
    LE_LEVEL3,
    LE_LEVEL4,
    LEVEL0,
    LEVEL1,
    LEVEL2,
    LEVEL3,
    LEVEL4,
    DeleteBondRequest,
    IsBondedRequest,
    LESecurityLevel,
    PairingEvent,
    PairingEventAnswer,
    SecureRequest,
    SecureResponse,
    SecurityLevel,
    WaitSecurityRequest,
    WaitSecurityResponse,
)
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, Optional, Union


class PairingDelegate(BasePairingDelegate):
    """Pairing delegate that relays pairing interactions through a `SecurityService`.

    Each interaction is wrapped in a `PairingEvent` and pushed onto the
    service's `event_queue`. Interactions that need a reply read the next
    `PairingEventAnswer` from the service's `event_answer` iterator.
    """

    def __init__(
        self,
        connection: BumbleConnection,
        service: "SecurityService",
        io_capability: BasePairingDelegate.IoCapability = BasePairingDelegate.NO_OUTPUT_NO_INPUT,
        local_initiator_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
        local_responder_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
    ) -> None:
        """Bind the delegate to `connection` and to the `service` owning the event stream."""
        self.log = utils.BumbleServerLoggerAdapter(
            logging.getLogger(),
            {'service_name': 'Security', 'device': connection.device},
        )
        self.connection = connection
        self.service = service
        super().__init__(
            io_capability,
            local_initiator_key_distribution,
            local_responder_key_distribution,
        )

    async def accept(self) -> bool:
        """Always accept the pairing request."""
        return True

    def add_origin(self, ev: PairingEvent) -> PairingEvent:
        """Tag `ev` with the connection it originates from and return it.

        A complete connection is identified by a cookie holding its handle as
        4 big-endian bytes; an incomplete one is identified by the peer
        address bytes in reversed order.
        """
        if not self.connection.is_incomplete:
            assert ev.connection
            ev.connection.CopyFrom(
                Connection(
                    cookie=any_pb2.Any(value=self.connection.handle.to_bytes(4, 'big'))
                )
            )
        else:
            # In BR/EDR, connection may not be complete,
            # use address instead
            assert self.connection.transport == BT_BR_EDR_TRANSPORT
            ev.address = bytes(reversed(bytes(self.connection.peer_address)))

        return ev

    async def confirm(self, auto: bool = False) -> bool:
        """Emit a `just_works` event and return the peer's confirmation.

        Returns True without emitting anything when no `OnPairing` stream is
        active.
        """
        self.log.debug(
            f"Pairing event: `just_works` (io_capability: {self.io_capability})"
        )

        if self.service.event_queue is None or self.service.event_answer is None:
            return True

        event = self.add_origin(PairingEvent(just_works=empty_pb2.Empty()))
        self.service.event_queue.put_nowait(event)
        answer = await anext(self.service.event_answer)  # type: ignore
        assert answer.event == event
        assert answer.answer_variant() == 'confirm' and answer.confirm is not None
        return answer.confirm

    async def compare_numbers(self, number: int, digits: int = 6) -> bool:
        """Emit a `numeric_comparison` event for `number` and return the confirmation.

        Raises:
            RuntimeError: if no `OnPairing` stream is active.
        """
        self.log.debug(
            f"Pairing event: `numeric_comparison` (io_capability: {self.io_capability})"
        )

        if self.service.event_queue is None or self.service.event_answer is None:
            raise RuntimeError('security: unhandled number comparison request')

        event = self.add_origin(PairingEvent(numeric_comparison=number))
        self.service.event_queue.put_nowait(event)
        answer = await anext(self.service.event_answer)  # type: ignore
        assert answer.event == event
        assert answer.answer_variant() == 'confirm' and answer.confirm is not None
        return answer.confirm

    async def get_number(self) -> Optional[int]:
        """Emit a `passkey_entry_request` event and return the answered passkey.

        Returns None when the answer carries no variant.

        Raises:
            RuntimeError: if no `OnPairing` stream is active.
        """
        self.log.debug(
            f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})"
        )

        if self.service.event_queue is None or self.service.event_answer is None:
            raise RuntimeError('security: unhandled number request')

        event = self.add_origin(PairingEvent(passkey_entry_request=empty_pb2.Empty()))
        self.service.event_queue.put_nowait(event)
        answer = await anext(self.service.event_answer)  # type: ignore
        assert answer.event == event
        if answer.answer_variant() is None:
            return None
        assert answer.answer_variant() == 'passkey'
        return answer.passkey

    async def get_string(self, max_length: int) -> Optional[str]:
        """Emit a `pin_code_request` event and return the answered PIN as text.

        Returns None when the answer carries no variant or no PIN.

        Raises:
            RuntimeError: if no `OnPairing` stream is active.
            ValueError: if the PIN is empty, is not valid UTF-8, or is longer
                than `max_length` bytes once encoded.
        """
        self.log.debug(
            f"Pairing event: `pin_code_request` (io_capability: {self.io_capability})"
        )

        if self.service.event_queue is None or self.service.event_answer is None:
            raise RuntimeError('security: unhandled pin_code request')

        event = self.add_origin(PairingEvent(pin_code_request=empty_pb2.Empty()))
        self.service.event_queue.put_nowait(event)
        answer = await anext(self.service.event_answer)  # type: ignore
        assert answer.event == event
        if answer.answer_variant() is None:
            return None
        assert answer.answer_variant() == 'pin'

        if answer.pin is None:
            return None

        pin = answer.pin.decode('utf-8')
        # The limit is expressed in bytes (see the error message), so check the
        # encoded length: a multi-byte character counts for more than one.
        if not pin or len(answer.pin) > max_length:
            raise ValueError(f'Pin must be utf-8 encoded up to {max_length} bytes')

        return pin

    async def display_number(self, number: int, digits: int = 6) -> None:
        """Emit a `passkey_entry_notification` event carrying `number`.

        Nothing is emitted on a BR/EDR connection when the IO capability is
        `DISPLAY_OUTPUT_ONLY`.

        Raises:
            RuntimeError: if no `OnPairing` stream is active.
        """
        if (
            self.connection.transport == BT_BR_EDR_TRANSPORT
            and self.io_capability == BasePairingDelegate.DISPLAY_OUTPUT_ONLY
        ):
            return

        self.log.debug(
            f"Pairing event: `passkey_entry_notification` (io_capability: {self.io_capability})"
        )

        if self.service.event_queue is None:
            raise RuntimeError('security: unhandled number display request')

        event = self.add_origin(PairingEvent(passkey_entry_notification=number))
        self.service.event_queue.put_nowait(event)


# Predicates telling whether a BR/EDR connection currently satisfies a level.
BR_LEVEL_REACHED: Dict[SecurityLevel, Callable[[BumbleConnection], bool]] = {
    LEVEL0: lambda connection: True,
    LEVEL1: lambda connection: connection.encryption == 0 or connection.authenticated,
    LEVEL2: lambda connection: connection.encryption != 0 and connection.authenticated,
    LEVEL3: lambda connection: connection.encryption != 0
    and connection.authenticated
    and connection.link_key_type
    in (
        hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE,
        hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
    ),
    LEVEL4: lambda connection: connection.encryption
    == hci.HCI_Encryption_Change_Event.AES_CCM
    and connection.authenticated
    and connection.link_key_type
    == hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
}

# Predicates telling whether an LE connection currently satisfies a level.
LE_LEVEL_REACHED: Dict[LESecurityLevel, Callable[[BumbleConnection], bool]] = {
    LE_LEVEL1: lambda connection: True,
    LE_LEVEL2: lambda connection: connection.encryption != 0,
    LE_LEVEL3: lambda connection: connection.encryption != 0
    and connection.authenticated,
    LE_LEVEL4: lambda connection: connection.encryption != 0
    and connection.authenticated
    and connection.sc,
}


class SecurityService(SecurityServicer):
    """gRPC `Security` servicer driving pairing, authentication and encryption."""

    def __init__(self, device: Device, config: Config) -> None:
        """Store `device`/`config` and install a pairing config factory on the device."""
        self.log = utils.BumbleServerLoggerAdapter(
            logging.getLogger(), {'service_name': 'Security', 'device': device}
        )
        # Both are set only while an `OnPairing` stream is active.
        self.event_queue: Optional[asyncio.Queue[PairingEvent]] = None
        self.event_answer: Optional[AsyncIterator[PairingEventAnswer]] = None
        self.device = device
        self.config = config

        def pairing_config_factory(connection: BumbleConnection) -> PairingConfig:
            return PairingConfig(
                sc=config.pairing_sc_enable,
                mitm=config.pairing_mitm_enable,
                bonding=config.pairing_bonding_enable,
                identity_address_type=(
                    PairingConfig.AddressType.PUBLIC
                    if connection.self_address.is_public
                    else config.identity_address_type
                ),
                delegate=PairingDelegate(
                    connection,
                    self,
                    io_capability=config.io_capability,
                    local_initiator_key_distribution=config.smp_local_initiator_key_distribution,
                    local_responder_key_distribution=config.smp_local_responder_key_distribution,
                ),
            )

        self.device.pairing_config_factory = pairing_config_factory

    @utils.rpc
    async def OnPairing(
        self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext
    ) -> AsyncGenerator[PairingEvent, None]:
        """Stream queued pairing events; `request` supplies the answers.

        Raises:
            RuntimeError: if a stream is already active, or if the device
                already has connections.
        """
        self.log.debug('OnPairing')

        if self.event_queue is not None:
            raise RuntimeError('already streaming pairing events')

        if len(self.device.connections):
            raise RuntimeError(
                'the `OnPairing` method shall be initiated before establishing any connections.'
            )

        self.event_queue = asyncio.Queue()
        self.event_answer = request

        try:
            while event := await self.event_queue.get():
                yield event

        finally:
            # Delegates check these for None to detect that no stream is active.
            self.event_queue = None
            self.event_answer = None

    @utils.rpc
    async def Secure(
        self, request: SecureRequest, context: grpc.ServicerContext
    ) -> SecureResponse:
        """Bring the connection to the requested level.

        Pairing, authentication and encryption are triggered in that order,
        each only when the matching `need_*` check asks for it. The response
        field names the first failure, or `success` / `not_reached` depending
        on the final `reached_security_level` check.
        """
        connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
        self.log.debug(f"Secure: {connection_handle}")

        connection = self.device.lookup_connection(connection_handle)
        assert connection

        oneof = request.WhichOneof('level')
        level = getattr(request, oneof)
        assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[
            connection.transport
        ] == oneof

        # security level already reached
        if self.reached_security_level(connection, level):
            return SecureResponse(success=empty_pb2.Empty())

        # trigger pairing if needed
        if self.need_pairing(connection, level):
            try:
                self.log.debug('Pair...')

                security_result = asyncio.get_running_loop().create_future()

                def resolve(outcome: str) -> None:
                    # Several events can fire before this coroutine resumes;
                    # only the first one decides the outcome. A second
                    # `set_result` would raise `InvalidStateError`.
                    if not security_result.done():
                        security_result.set_result(outcome)

                with contextlib.closing(EventWatcher()) as watcher:

                    @watcher.on(connection, 'pairing')
                    def on_pairing(*_: Any) -> None:
                        resolve('success')

                    @watcher.on(connection, 'pairing_failure')
                    def on_pairing_failure(*_: Any) -> None:
                        resolve('pairing_failure')

                    @watcher.on(connection, 'disconnection')
                    def on_disconnection(*_: Any) -> None:
                        resolve('connection_died')

                    if (
                        connection.transport == BT_LE_TRANSPORT
                        and connection.role == BT_PERIPHERAL_ROLE
                    ):
                        connection.request_pairing()
                    else:
                        await connection.pair()

                    result = await security_result

                self.log.debug(f'Pairing session complete, status={result}')
                if result != 'success':
                    # `result` is the name of the response field to set.
                    return SecureResponse(**{result: empty_pb2.Empty()})
            except asyncio.CancelledError:
                self.log.warning("Connection died during pairing")
                return SecureResponse(connection_died=empty_pb2.Empty())
            except (HCI_Error, ProtocolError) as e:
                self.log.warning(f"Pairing failure: {e}")
                return SecureResponse(pairing_failure=empty_pb2.Empty())

        # trigger authentication if needed
        if self.need_authentication(connection, level):
            try:
                self.log.debug('Authenticate...')
                await connection.authenticate()
                self.log.debug('Authenticated')
            except asyncio.CancelledError:
                self.log.warning("Connection died during authentication")
                return SecureResponse(connection_died=empty_pb2.Empty())
            except (HCI_Error, ProtocolError) as e:
                self.log.warning(f"Authentication failure: {e}")
                return SecureResponse(authentication_failure=empty_pb2.Empty())

        # trigger encryption if needed
        if self.need_encryption(connection, level):
            try:
                self.log.debug('Encrypt...')
                await connection.encrypt()
                self.log.debug('Encrypted')
            except asyncio.CancelledError:
                self.log.warning("Connection died during encryption")
                return SecureResponse(connection_died=empty_pb2.Empty())
            except (HCI_Error, ProtocolError) as e:
                self.log.warning(f"Encryption failure: {e}")
                return SecureResponse(encryption_failure=empty_pb2.Empty())

        # security level has been reached ?
        if self.reached_security_level(connection, level):
            return SecureResponse(success=empty_pb2.Empty())
        return SecureResponse(not_reached=empty_pb2.Empty())

    @utils.rpc
    async def WaitSecurity(
        self, request: WaitSecurityRequest, context: grpc.ServicerContext
    ) -> WaitSecurityResponse:
        """Wait until the connection reaches the requested level or a failure occurs.

        Listeners are registered on the connection's security-related events.
        The first one to report an outcome chooses the response field. Any
        authentication or pairing task started along the way is awaited
        before the response is returned.
        """
        connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
        self.log.debug(f"WaitSecurity: {connection_handle}")

        connection = self.device.lookup_connection(connection_handle)
        assert connection

        assert request.level
        level = request.level
        assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[
            connection.transport
        ] == request.level_variant()

        wait_for_security: asyncio.Future[str] = (
            asyncio.get_running_loop().create_future()
        )
        authenticate_task: Optional[asyncio.Future[None]] = None
        pair_task: Optional[asyncio.Future[None]] = None

        def resolve(outcome: str) -> None:
            # Several listeners can fire before this coroutine resumes; only
            # the first outcome is kept. A second `set_result` would raise
            # `InvalidStateError` inside the event emitter.
            if not wait_for_security.done():
                wait_for_security.set_result(outcome)

        async def authenticate() -> None:
            assert connection
            if (encryption := connection.encryption) != 0:
                self.log.debug('Disable encryption...')
                try:
                    await connection.encrypt(enable=False)
                except (Exception, asyncio.CancelledError) as e:
                    # Best effort: authentication is attempted regardless.
                    self.log.debug(f'Disable encryption: ignored error: {e!r}')
                self.log.debug('Disable encryption: done')

            self.log.debug('Authenticate...')
            await connection.authenticate()
            self.log.debug('Authenticate: done')

            if encryption != 0 and connection.encryption != encryption:
                self.log.debug('Re-enable encryption...')
                await connection.encrypt()
                self.log.debug('Re-enable encryption: done')

        def set_failure(name: str) -> Callable[..., None]:
            def wrapper(*args: Any) -> None:
                self.log.debug(f'Wait for security: error `{name}`: {args}')
                resolve(name)

            return wrapper

        def try_set_success(*_: Any) -> None:
            assert connection
            if self.reached_security_level(connection, level):
                self.log.debug('Wait for security: done')
                resolve('success')

        def on_encryption_change(*_: Any) -> None:
            assert connection
            if self.reached_security_level(connection, level):
                self.log.debug('Wait for security: done')
                resolve('success')
            elif (
                connection.transport == BT_BR_EDR_TRANSPORT
                and self.need_authentication(connection, level)
            ):
                nonlocal authenticate_task
                if authenticate_task is None:
                    authenticate_task = asyncio.create_task(authenticate())

        def pair(*_: Any) -> None:
            # Without `nonlocal` the assignment would create a local variable
            # and the task would never be awaited below.
            nonlocal pair_task
            assert connection
            if self.need_pairing(connection, level):
                pair_task = asyncio.create_task(connection.pair())

        listeners: Dict[str, Callable[..., None]] = {
            'disconnection': set_failure('connection_died'),
            'pairing_failure': set_failure('pairing_failure'),
            'connection_authentication_failure': set_failure('authentication_failure'),
            'connection_encryption_failure': set_failure('encryption_failure'),
            'pairing': try_set_success,
            'connection_authentication': try_set_success,
            'connection_encryption_change': on_encryption_change,
            'classic_pairing': try_set_success,
            'classic_pairing_failure': set_failure('pairing_failure'),
            'security_request': pair,
        }

        with contextlib.closing(EventWatcher()) as watcher:
            # register event handlers
            for event, listener in listeners.items():
                watcher.on(connection, event, listener)

            # security level already reached
            if self.reached_security_level(connection, level):
                return WaitSecurityResponse(success=empty_pb2.Empty())

            self.log.debug('Wait for security...')
            kwargs = {}
            # The outcome string is the name of the response field to set.
            kwargs[await wait_for_security] = empty_pb2.Empty()

        # wait for `authenticate` to finish if any
        if authenticate_task is not None:
            self.log.debug('Wait for authentication...')
            try:
                await authenticate_task  # type: ignore
            except (Exception, asyncio.CancelledError) as e:
                # The outcome is already decided; a failure here is only logged.
                self.log.debug(f'Authentication task ended with error: {e!r}')
            self.log.debug('Authenticated')

        # wait for `pair` to finish if any
        if pair_task is not None:
            self.log.debug('Wait for pairing...')
            try:
                await pair_task  # type: ignore
            except (Exception, asyncio.CancelledError) as e:
                # The outcome is already decided; a failure here is only logged.
                self.log.debug(f'Pairing task ended with error: {e!r}')
            self.log.debug('paired')

        return WaitSecurityResponse(**kwargs)

    def reached_security_level(
        self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel]
    ) -> bool:
        """Log the connection's security state and test it against `level`.

        `LESecurityLevel` values are looked up in `LE_LEVEL_REACHED`, any
        other value in `BR_LEVEL_REACHED`.
        """
        self.log.debug(
            str(
                {
                    'level': level,
                    'encryption': connection.encryption,
                    'authenticated': connection.authenticated,
                    'sc': connection.sc,
                    'link_key_type': connection.link_key_type,
                }
            )
        )

        if isinstance(level, LESecurityLevel):
            return LE_LEVEL_REACHED[level](connection)

        return BR_LEVEL_REACHED[level](connection)

    def need_pairing(self, connection: BumbleConnection, level: int) -> bool:
        """Return True for an unauthenticated LE connection asked for `LE_LEVEL3` or above."""
        if connection.transport == BT_LE_TRANSPORT:
            return level >= LE_LEVEL3 and not connection.authenticated
        return False

    def need_authentication(self, connection: BumbleConnection, level: int) -> bool:
        """Return True for an unauthenticated non-LE connection asked for `LEVEL2` or above."""
        if connection.transport == BT_LE_TRANSPORT:
            return False
        return level >= LEVEL2 and not connection.authenticated

    def need_encryption(self, connection: BumbleConnection, level: int) -> bool:
        """Return True when the connection is unencrypted and `level` calls for encryption.

        On LE only `LE_LEVEL2` triggers it; on other transports any level from
        `LEVEL2` up does.
        """
        # TODO(abel): need to support MITM
        if connection.transport == BT_LE_TRANSPORT:
            return level == LE_LEVEL2 and not connection.encryption
        return level >= LEVEL2 and not connection.encryption


class SecurityStorageService(SecurityStorageServicer):
    """gRPC `SecurityStorage` servicer backed by the device keystore."""

    def __init__(self, device: Device, config: Config) -> None:
        """Store the `device` whose keystore is queried, and the `config`."""
        self.log = utils.BumbleServerLoggerAdapter(
            logging.getLogger(), {'service_name': 'SecurityStorage', 'device': device}
        )
        self.device = device
        self.config = config

    @utils.rpc
    async def IsBonded(
        self, request: IsBondedRequest, context: grpc.ServicerContext
    ) -> wrappers_pb2.BoolValue:
        """Report whether the keystore has an entry for the requested address.

        The value is False when the device has no keystore.
        """
        address = utils.address_from_request(request, request.WhichOneof("address"))
        self.log.debug(f"IsBonded: {address}")

        if self.device.keystore is not None:
            is_bonded = await self.device.keystore.get(str(address)) is not None
        else:
            is_bonded = False

        return wrappers_pb2.BoolValue(value=is_bonded)

    @utils.rpc
    async def DeleteBond(
        self, request: DeleteBondRequest, context: grpc.ServicerContext
    ) -> empty_pb2.Empty:
        """Delete the keystore entry for the requested address, if any.

        A missing entry (`KeyError`) or a missing keystore is not an error.
        """
        address = utils.address_from_request(request, request.WhichOneof("address"))
        self.log.debug(f"DeleteBond: {address}")

        if self.device.keystore is not None:
            with contextlib.suppress(KeyError):
                await self.device.keystore.delete(str(address))

        return empty_pb2.Empty()