1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from __future__ import annotations
16import asyncio
17import contextlib
18import grpc
19import logging
20
21from . import utils
22from .config import Config
23from bumble import hci
24from bumble.core import (
25    BT_BR_EDR_TRANSPORT,
26    BT_LE_TRANSPORT,
27    BT_PERIPHERAL_ROLE,
28    ProtocolError,
29)
30from bumble.device import Connection as BumbleConnection, Device
31from bumble.hci import HCI_Error
32from bumble.utils import EventWatcher
33from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
34from google.protobuf import any_pb2  # pytype: disable=pyi-error
35from google.protobuf import empty_pb2  # pytype: disable=pyi-error
36from google.protobuf import wrappers_pb2  # pytype: disable=pyi-error
37from pandora.host_pb2 import Connection
38from pandora.security_grpc_aio import SecurityServicer, SecurityStorageServicer
39from pandora.security_pb2 import (
40    LE_LEVEL1,
41    LE_LEVEL2,
42    LE_LEVEL3,
43    LE_LEVEL4,
44    LEVEL0,
45    LEVEL1,
46    LEVEL2,
47    LEVEL3,
48    LEVEL4,
49    DeleteBondRequest,
50    IsBondedRequest,
51    LESecurityLevel,
52    PairingEvent,
53    PairingEventAnswer,
54    SecureRequest,
55    SecureResponse,
56    SecurityLevel,
57    WaitSecurityRequest,
58    WaitSecurityResponse,
59)
60from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, Optional, Union
61
62
class PairingDelegate(BasePairingDelegate):
    """Bumble pairing delegate that forwards pairing prompts to `OnPairing`.

    Every prompt is turned into a `PairingEvent` and pushed on the owning
    `SecurityService`'s event queue. Prompts that need an answer then wait for
    the next `PairingEventAnswer` read from the service's answer stream.
    """

    def __init__(
        self,
        connection: BumbleConnection,
        service: "SecurityService",
        io_capability: BasePairingDelegate.IoCapability = BasePairingDelegate.NO_OUTPUT_NO_INPUT,
        local_initiator_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
        local_responder_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
    ) -> None:
        self.log = utils.BumbleServerLoggerAdapter(
            logging.getLogger(),
            {'service_name': 'Security', 'device': connection.device},
        )
        self.connection = connection
        self.service = service
        super().__init__(
            io_capability,
            local_initiator_key_distribution,
            local_responder_key_distribution,
        )

    async def accept(self) -> bool:
        # Pairing requests are always accepted.
        return True

    def add_origin(self, ev: PairingEvent) -> PairingEvent:
        """Tag `ev` with the connection it belongs to and return it.

        For a complete connection, the handle (4 bytes, big endian) is used as
        the connection cookie. Otherwise, which is only expected on BR/EDR,
        the peer address is stored in byte-reversed order.
        """
        if not self.connection.is_incomplete:
            assert ev.connection
            ev.connection.CopyFrom(
                Connection(
                    cookie=any_pb2.Any(value=self.connection.handle.to_bytes(4, 'big'))
                )
            )
        else:
            # In BR/EDR, connection may not be complete,
            # use address instead
            assert self.connection.transport == BT_BR_EDR_TRANSPORT
            ev.address = bytes(reversed(bytes(self.connection.peer_address)))

        return ev

    async def _exchange(self, event: PairingEvent) -> PairingEventAnswer:
        """Push `event` to the `OnPairing` stream and wait for its answer.

        Callers must already have checked that the service's `event_queue`
        and `event_answer` are both set.
        """
        assert self.service.event_queue is not None
        assert self.service.event_answer is not None
        self.service.event_queue.put_nowait(event)
        answer = await anext(self.service.event_answer)  # type: ignore
        # The answer must refer to the event that was just sent.
        assert answer.event == event
        return answer

    async def confirm(self, auto: bool = False) -> bool:
        """Ask the client to confirm a `just_works` pairing.

        If no `OnPairing` stream is active, the pairing is accepted.
        """
        self.log.debug(
            f"Pairing event: `just_works` (io_capability: {self.io_capability})"
        )

        if self.service.event_queue is None or self.service.event_answer is None:
            return True

        answer = await self._exchange(
            self.add_origin(PairingEvent(just_works=empty_pb2.Empty()))
        )
        assert answer.answer_variant() == 'confirm' and answer.confirm is not None
        return answer.confirm

    async def compare_numbers(self, number: int, digits: int = 6) -> bool:
        """Ask the client to confirm a numeric comparison value.

        Raises:
            RuntimeError: no `OnPairing` stream is active.
        """
        self.log.debug(
            f"Pairing event: `numeric_comparison` (io_capability: {self.io_capability})"
        )

        if self.service.event_queue is None or self.service.event_answer is None:
            raise RuntimeError('security: unhandled number comparison request')

        answer = await self._exchange(
            self.add_origin(PairingEvent(numeric_comparison=number))
        )
        assert answer.answer_variant() == 'confirm' and answer.confirm is not None
        return answer.confirm

    async def get_number(self) -> Optional[int]:
        """Ask the client for a passkey.

        Returns:
            The passkey, or `None` if the answer carries no variant.

        Raises:
            RuntimeError: no `OnPairing` stream is active.
        """
        self.log.debug(
            f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})"
        )

        if self.service.event_queue is None or self.service.event_answer is None:
            raise RuntimeError('security: unhandled number request')

        answer = await self._exchange(
            self.add_origin(PairingEvent(passkey_entry_request=empty_pb2.Empty()))
        )
        if answer.answer_variant() is None:
            return None
        assert answer.answer_variant() == 'passkey'
        return answer.passkey

    async def get_string(self, max_length: int) -> Optional[str]:
        """Ask the client for a PIN code.

        Args:
            max_length: maximum size, in bytes, of the UTF-8 encoded PIN.

        Returns:
            The decoded PIN, or `None` if the answer carries no PIN.

        Raises:
            RuntimeError: no `OnPairing` stream is active.
            ValueError: the PIN is empty, not valid UTF-8, or longer than
                `max_length` bytes once encoded.
        """
        self.log.debug(
            f"Pairing event: `pin_code_request` (io_capability: {self.io_capability})"
        )

        if self.service.event_queue is None or self.service.event_answer is None:
            raise RuntimeError('security: unhandled pin_code request')

        answer = await self._exchange(
            self.add_origin(PairingEvent(pin_code_request=empty_pb2.Empty()))
        )
        if answer.answer_variant() is None:
            return None
        assert answer.answer_variant() == 'pin'

        if answer.pin is None:
            return None

        pin = answer.pin.decode('utf-8')
        # The limit applies to the encoded size: count bytes, not characters
        # (a multi-byte UTF-8 PIN has fewer characters than bytes).
        if not pin or len(answer.pin) > max_length:
            raise ValueError(f'Pin must be utf-8 encoded up to {max_length} bytes')

        return pin

    async def display_number(self, number: int, digits: int = 6) -> None:
        """Notify the client of a passkey to display.

        No event is emitted on BR/EDR when the I/O capability is
        `DISPLAY_OUTPUT_ONLY`.

        Raises:
            RuntimeError: no `OnPairing` stream is active.
        """
        if (
            self.connection.transport == BT_BR_EDR_TRANSPORT
            and self.io_capability == BasePairingDelegate.DISPLAY_OUTPUT_ONLY
        ):
            return

        self.log.debug(
            f"Pairing event: `passkey_entry_notification` (io_capability: {self.io_capability})"
        )

        if self.service.event_queue is None:
            raise RuntimeError('security: unhandled number display request')

        event = self.add_origin(PairingEvent(passkey_entry_notification=number))
        self.service.event_queue.put_nowait(event)
191
192
# Predicates telling whether a BR/EDR connection currently satisfies a given
# classic `SecurityLevel`. Each predicate reads only the connection's
# `encryption`, `authenticated` and `link_key_type` attributes.
BR_LEVEL_REACHED: Dict[SecurityLevel, Callable[[BumbleConnection], bool]] = {
    # Level 0: always satisfied.
    LEVEL0: lambda connection: True,
    # Level 1: satisfied by an unencrypted link, or by an authenticated one.
    LEVEL1: lambda connection: (
        connection.encryption == 0 or connection.authenticated
    ),
    # Level 2: encrypted and authenticated.
    LEVEL2: lambda connection: (
        connection.encryption != 0 and connection.authenticated
    ),
    # Level 3: level 2 with an authenticated combination key (P-192 or P-256).
    LEVEL3: lambda connection: (
        connection.encryption != 0
        and connection.authenticated
        and connection.link_key_type
        in (
            hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE,
            hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
        )
    ),
    # Level 4: AES-CCM encryption, authenticated, with a P-256 authenticated
    # combination key.
    LEVEL4: lambda connection: (
        connection.encryption == hci.HCI_Encryption_Change_Event.AES_CCM
        and connection.authenticated
        and connection.link_key_type
        == hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE
    ),
}
210
# Predicates telling whether an LE connection currently satisfies a given
# `LESecurityLevel`. Each predicate reads only the connection's `encryption`,
# `authenticated` and `sc` attributes.
LE_LEVEL_REACHED: Dict[LESecurityLevel, Callable[[BumbleConnection], bool]] = {
    # Level 1: always satisfied.
    LE_LEVEL1: lambda connection: True,
    # Level 2: the link is encrypted.
    LE_LEVEL2: lambda connection: (connection.encryption != 0),
    # Level 3: encrypted and authenticated.
    LE_LEVEL3: lambda connection: (
        connection.encryption != 0 and connection.authenticated
    ),
    # Level 4: level 3 with `sc` set (presumably LE Secure Connections).
    LE_LEVEL4: lambda connection: (
        connection.encryption != 0
        and connection.authenticated
        and connection.sc
    ),
}
220
221
class SecurityService(SecurityServicer):
    """Pandora `Security` gRPC service implemented on a Bumble `Device`.

    `OnPairing` streams pairing prompts raised by `PairingDelegate` to the
    client and feeds the client's answers back. `Secure` actively pairs,
    authenticates and/or encrypts a connection until a requested security
    level is reached. `WaitSecurity` waits for that level to be reached.
    """

    def __init__(self, device: Device, config: Config) -> None:
        self.log = utils.BumbleServerLoggerAdapter(
            logging.getLogger(), {'service_name': 'Security', 'device': device}
        )
        # Pairing events to stream from `OnPairing`, and the client's answers.
        # Both are `None` while no `OnPairing` call is active.
        self.event_queue: Optional[asyncio.Queue[PairingEvent]] = None
        self.event_answer: Optional[AsyncIterator[PairingEventAnswer]] = None
        self.device = device
        self.config = config

        def pairing_config_factory(connection: BumbleConnection) -> PairingConfig:
            # Per-connection pairing configuration derived from `config`. A
            # public own address forces a public identity address type.
            return PairingConfig(
                sc=config.pairing_sc_enable,
                mitm=config.pairing_mitm_enable,
                bonding=config.pairing_bonding_enable,
                identity_address_type=(
                    PairingConfig.AddressType.PUBLIC
                    if connection.self_address.is_public
                    else config.identity_address_type
                ),
                delegate=PairingDelegate(
                    connection,
                    self,
                    io_capability=config.io_capability,
                    local_initiator_key_distribution=config.smp_local_initiator_key_distribution,
                    local_responder_key_distribution=config.smp_local_responder_key_distribution,
                ),
            )

        self.device.pairing_config_factory = pairing_config_factory

    @staticmethod
    def _set_result_once(future: asyncio.Future[str], result: str) -> bool:
        """Resolve `future` with `result` unless it is already done.

        Several connection events (e.g. `pairing_failure` followed by
        `disconnection`) may fire before the waiting coroutine resumes.
        Calling `set_result` a second time would raise
        `asyncio.InvalidStateError` from inside the event handler.

        Returns:
            True if this call resolved the future, False if it was already done.
        """
        if future.done():
            return False
        future.set_result(result)
        return True

    async def _wait_for_task(self, task: asyncio.Future[None], name: str) -> None:
        """Wait for background `task` to finish; only log how it ended.

        The task's own failure or cancellation is not propagated. Cancellation
        of the calling coroutine still propagates, which a bare `except:`
        would have swallowed.
        """
        await asyncio.wait({task})
        if task.cancelled():
            self.log.debug(f'`{name}` task cancelled')
        elif (error := task.exception()) is not None:
            self.log.debug(f'`{name}` task failed: {error}')

    @utils.rpc
    async def OnPairing(
        self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext
    ) -> AsyncGenerator[PairingEvent, None]:
        """Stream pairing events to the client; `request` carries its answers.

        Raises:
            RuntimeError: a stream is already active, or connections already exist.
        """
        self.log.debug('OnPairing')

        if self.event_queue is not None:
            raise RuntimeError('already streaming pairing events')

        if len(self.device.connections):
            raise RuntimeError(
                'the `OnPairing` method shall be initiated before establishing any connections.'
            )

        self.event_queue = asyncio.Queue()
        self.event_answer = request

        try:
            while event := await self.event_queue.get():
                yield event

        finally:
            self.event_queue = None
            self.event_answer = None

    @utils.rpc
    async def Secure(
        self, request: SecureRequest, context: grpc.ServicerContext
    ) -> SecureResponse:
        """Pair, authenticate and/or encrypt until the requested level is reached."""
        connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
        self.log.debug(f"Secure: {connection_handle}")

        connection = self.device.lookup_connection(connection_handle)
        assert connection

        oneof = request.WhichOneof('level')
        level = getattr(request, oneof)
        assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[
            connection.transport
        ] == oneof

        # security level already reached
        if self.reached_security_level(connection, level):
            return SecureResponse(success=empty_pb2.Empty())

        # trigger pairing if needed
        if self.need_pairing(connection, level):
            try:
                self.log.debug('Pair...')

                security_result = asyncio.get_running_loop().create_future()

                with contextlib.closing(EventWatcher()) as watcher:

                    @watcher.on(connection, 'pairing')
                    def on_pairing(*_: Any) -> None:
                        self._set_result_once(security_result, 'success')

                    @watcher.on(connection, 'pairing_failure')
                    def on_pairing_failure(*_: Any) -> None:
                        self._set_result_once(security_result, 'pairing_failure')

                    @watcher.on(connection, 'disconnection')
                    def on_disconnection(*_: Any) -> None:
                        self._set_result_once(security_result, 'connection_died')

                    if (
                        connection.transport == BT_LE_TRANSPORT
                        and connection.role == BT_PERIPHERAL_ROLE
                    ):
                        # An LE peripheral can only ask the central to pair.
                        connection.request_pairing()
                    else:
                        await connection.pair()

                    result = await security_result

                self.log.debug(f'Pairing session complete, status={result}')
                if result != 'success':
                    return SecureResponse(**{result: empty_pb2.Empty()})
            except asyncio.CancelledError:
                self.log.warning("Connection died during encryption")
                return SecureResponse(connection_died=empty_pb2.Empty())
            except (HCI_Error, ProtocolError) as e:
                self.log.warning(f"Pairing failure: {e}")
                return SecureResponse(pairing_failure=empty_pb2.Empty())

        # trigger authentication if needed
        if self.need_authentication(connection, level):
            try:
                self.log.debug('Authenticate...')
                await connection.authenticate()
                self.log.debug('Authenticated')
            except asyncio.CancelledError:
                self.log.warning("Connection died during authentication")
                return SecureResponse(connection_died=empty_pb2.Empty())
            except (HCI_Error, ProtocolError) as e:
                self.log.warning(f"Authentication failure: {e}")
                return SecureResponse(authentication_failure=empty_pb2.Empty())

        # trigger encryption if needed
        if self.need_encryption(connection, level):
            try:
                self.log.debug('Encrypt...')
                await connection.encrypt()
                self.log.debug('Encrypted')
            except asyncio.CancelledError:
                self.log.warning("Connection died during encryption")
                return SecureResponse(connection_died=empty_pb2.Empty())
            except (HCI_Error, ProtocolError) as e:
                self.log.warning(f"Encryption failure: {e}")
                return SecureResponse(encryption_failure=empty_pb2.Empty())

        # security level has been reached ?
        if self.reached_security_level(connection, level):
            return SecureResponse(success=empty_pb2.Empty())
        return SecureResponse(not_reached=empty_pb2.Empty())

    @utils.rpc
    async def WaitSecurity(
        self, request: WaitSecurityRequest, context: grpc.ServicerContext
    ) -> WaitSecurityResponse:
        """Wait until the requested security level is reached, or a failure occurs.

        On BR/EDR, this may trigger authentication after an encryption change.
        On LE, it pairs when the peer sends a security request.
        """
        connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
        self.log.debug(f"WaitSecurity: {connection_handle}")

        connection = self.device.lookup_connection(connection_handle)
        assert connection

        assert request.level
        level = request.level
        assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[
            connection.transport
        ] == request.level_variant()

        wait_for_security: asyncio.Future[str] = (
            asyncio.get_running_loop().create_future()
        )
        # Background tasks spawned by the event handlers. The references are
        # kept so the tasks cannot be garbage-collected, and so they can be
        # awaited before returning.
        authenticate_task: Optional[asyncio.Future[None]] = None
        pair_task: Optional[asyncio.Future[None]] = None

        async def authenticate() -> None:
            assert connection
            if (encryption := connection.encryption) != 0:
                self.log.debug('Disable encryption...')
                try:
                    await connection.encrypt(enable=False)
                except Exception as error:
                    # Best effort: authentication is attempted regardless.
                    self.log.debug(f'Disable encryption failed: {error}')
                self.log.debug('Disable encryption: done')

            self.log.debug('Authenticate...')
            await connection.authenticate()
            self.log.debug('Authenticate: done')

            if encryption != 0 and connection.encryption != encryption:
                self.log.debug('Re-enable encryption...')
                await connection.encrypt()
                self.log.debug('Re-enable encryption: done')

        def set_failure(name: str) -> Callable[..., None]:
            def wrapper(*args: Any) -> None:
                self.log.debug(f'Wait for security: error `{name}`: {args}')
                self._set_result_once(wait_for_security, name)

            return wrapper

        def try_set_success(*_: Any) -> None:
            assert connection
            if wait_for_security.done():
                return
            if self.reached_security_level(connection, level):
                self.log.debug('Wait for security: done')
                self._set_result_once(wait_for_security, 'success')

        def on_encryption_change(*_: Any) -> None:
            nonlocal authenticate_task
            assert connection
            if wait_for_security.done():
                return
            if self.reached_security_level(connection, level):
                self.log.debug('Wait for security: done')
                self._set_result_once(wait_for_security, 'success')
            elif (
                connection.transport == BT_BR_EDR_TRANSPORT
                and self.need_authentication(connection, level)
            ):
                if authenticate_task is None:
                    authenticate_task = asyncio.create_task(authenticate())

        def pair(*_: Any) -> None:
            # `nonlocal` is required: without it the task would be bound to a
            # local name, never awaited, and left unreferenced.
            nonlocal pair_task
            assert connection
            if pair_task is None and self.need_pairing(connection, level):
                pair_task = asyncio.create_task(connection.pair())

        listeners: Dict[str, Callable[..., None]] = {
            'disconnection': set_failure('connection_died'),
            'pairing_failure': set_failure('pairing_failure'),
            'connection_authentication_failure': set_failure('authentication_failure'),
            'connection_encryption_failure': set_failure('encryption_failure'),
            'pairing': try_set_success,
            'connection_authentication': try_set_success,
            'connection_encryption_change': on_encryption_change,
            'classic_pairing': try_set_success,
            'classic_pairing_failure': set_failure('pairing_failure'),
            'security_request': pair,
        }

        with contextlib.closing(EventWatcher()) as watcher:
            # register event handlers
            for event, listener in listeners.items():
                watcher.on(connection, event, listener)

            # security level already reached
            if self.reached_security_level(connection, level):
                return WaitSecurityResponse(success=empty_pb2.Empty())

            self.log.debug('Wait for security...')
            result = await wait_for_security

        # wait for `authenticate` to finish if any
        if authenticate_task is not None:
            self.log.debug('Wait for authentication...')
            await self._wait_for_task(authenticate_task, 'authenticate')
            self.log.debug('Authenticated')

        # wait for `pair` to finish if any
        if pair_task is not None:
            self.log.debug('Wait for pairing...')
            await self._wait_for_task(pair_task, 'pair')
            self.log.debug('paired')

        return WaitSecurityResponse(**{result: empty_pb2.Empty()})

    def reached_security_level(
        self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel]
    ) -> bool:
        """Return whether `connection` currently satisfies `level`."""
        self.log.debug(
            str(
                {
                    'level': level,
                    'encryption': connection.encryption,
                    'authenticated': connection.authenticated,
                    'sc': connection.sc,
                    'link_key_type': connection.link_key_type,
                }
            )
        )

        if isinstance(level, LESecurityLevel):
            return LE_LEVEL_REACHED[level](connection)

        return BR_LEVEL_REACHED[level](connection)

    def need_pairing(self, connection: BumbleConnection, level: int) -> bool:
        """Pairing is only driven explicitly on LE, for level 3 and above."""
        if connection.transport == BT_LE_TRANSPORT:
            return level >= LE_LEVEL3 and not connection.authenticated
        return False

    def need_authentication(self, connection: BumbleConnection, level: int) -> bool:
        """BR/EDR needs authentication for level 2+ when not yet authenticated."""
        if connection.transport == BT_LE_TRANSPORT:
            return False
        return level >= LEVEL2 and not connection.authenticated

    def need_encryption(self, connection: BumbleConnection, level: int) -> bool:
        """Return whether encryption must still be enabled to reach `level`."""
        # TODO(abel): need to support MITM
        if connection.transport == BT_LE_TRANSPORT:
            return level == LE_LEVEL2 and not connection.encryption
        return level >= LEVEL2 and not connection.encryption
524
525
class SecurityStorageService(SecurityStorageServicer):
    """Pandora `SecurityStorage` gRPC service backed by the device keystore."""

    def __init__(self, device: Device, config: Config) -> None:
        self.log = utils.BumbleServerLoggerAdapter(
            logging.getLogger(), {'service_name': 'SecurityStorage', 'device': device}
        )
        self.device = device
        self.config = config

    @utils.rpc
    async def IsBonded(
        self, request: IsBondedRequest, context: grpc.ServicerContext
    ) -> wrappers_pb2.BoolValue:
        """Report whether the keystore holds an entry for the requested address.

        Without a keystore, the answer is always `False`.
        """
        address = utils.address_from_request(request, request.WhichOneof("address"))
        self.log.debug(f"IsBonded: {address}")

        keystore = self.device.keystore
        bonded = False
        if keystore is not None:
            stored = await keystore.get(str(address))
            bonded = stored is not None

        return wrappers_pb2.BoolValue(value=bonded)

    @utils.rpc
    async def DeleteBond(
        self, request: DeleteBondRequest, context: grpc.ServicerContext
    ) -> empty_pb2.Empty:
        """Remove the keystore entry for the requested address, if present.

        A missing keystore, or a `KeyError` from the keystore, is ignored.
        """
        address = utils.address_from_request(request, request.WhichOneof("address"))
        self.log.debug(f"DeleteBond: {address}")

        keystore = self.device.keystore
        if keystore is None:
            return empty_pb2.Empty()

        with contextlib.suppress(KeyError):
            await keystore.delete(str(address))
        return empty_pb2.Empty()
560