1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import os
20import logging
21import click
22from prompt_toolkit.shortcuts import PromptSession
23
24from bumble.colors import color
25from bumble.device import Device, Peer
26from bumble.transport import open_transport_or_link
27from bumble.pairing import OobData, PairingDelegate, PairingConfig
28from bumble.smp import OobContext, OobLegacyContext
29from bumble.smp import error_name as smp_error_name
30from bumble.keys import JsonKeyStore
31from bumble.core import (
32    AdvertisingData,
33    ProtocolError,
34    BT_LE_TRANSPORT,
35    BT_BR_EDR_TRANSPORT,
36)
37from bumble.gatt import (
38    GATT_DEVICE_NAME_CHARACTERISTIC,
39    GATT_GENERIC_ACCESS_SERVICE,
40    Service,
41    Characteristic,
42    CharacteristicValue,
43)
44from bumble.att import (
45    ATT_Error,
46    ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
47    ATT_INSUFFICIENT_ENCRYPTION_ERROR,
48)
49from bumble.utils import AsyncRunner
50
51# -----------------------------------------------------------------------------
52# Constants
53# -----------------------------------------------------------------------------
# Delay (passed to asyncio.sleep, i.e. seconds) between a successful pairing
# and the disconnection that follows it
POST_PAIRING_DELAY = 1
55
56
57# -----------------------------------------------------------------------------
class Waiter:
    """One-shot completion signal backed by an asyncio future.

    `terminate()` resolves the future (unless lingering was requested) and
    `wait_until_terminated()` blocks until that happens.
    """

    # The waiter shared by the module-level event handlers
    instance = None

    def __init__(self, linger=False):
        """Create the waiter.

        Must be called while an event loop is running, since the future is
        created on the running loop.

        Args:
            linger: when true, `terminate()` does nothing, so waiters are
                never released by it.
        """
        self.done = asyncio.get_running_loop().create_future()
        self.linger = linger

    def terminate(self):
        """Release whoever is waiting, unless lingering was requested.

        Safe to call more than once: resolving an already-resolved future
        would raise `asyncio.InvalidStateError`, so repeated calls are ignored.
        """
        if self.linger or self.done.done():
            return

        self.done.set_result(None)

    async def wait_until_terminated(self):
        """Wait until `terminate()` has resolved the future; returns None."""
        return await self.done
71
72
73# -----------------------------------------------------------------------------
class Delegate(PairingDelegate):
    """Pairing delegate that interacts with the user on the console.

    Messages are printed in yellow and answers are read with an asynchronous
    prompt. The peer's name is looked up lazily, the first time it is needed
    for a message.
    """

    # Horizontal rule printed above and below each pairing message
    BANNER_RULE = '###-----------------------------------'

    def __init__(self, mode, connection, capability_string, do_prompt):
        """Create a delegate for one connection.

        Args:
            mode: pairing mode, forwarded to `get_peer_name` ('classic' selects
                the classic name request).
            connection: the connection being paired, wrapped in a `Peer`.
            capability_string: one of 'keyboard', 'display',
                'display+keyboard', 'display+yes/no' or 'none'
                (case-insensitive); any other value raises `KeyError`.
            do_prompt: when true, `accept()` asks the user for confirmation
                instead of accepting silently.
        """
        super().__init__(
            io_capability={
                'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
                'display': PairingDelegate.DISPLAY_OUTPUT_ONLY,
                'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT,
                'display+yes/no': PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT,
                'none': PairingDelegate.NO_OUTPUT_NO_INPUT,
            }[capability_string.lower()]
        )

        self.mode = mode
        self.peer = Peer(connection)
        self.peer_name = None
        self.do_prompt = do_prompt

    def print(self, message):
        """Print a message in the delegate's color (yellow)."""
        print(color(message, 'yellow'))

    async def prompt(self, message, *, lowercase=True):
        """Ask the user a question and return the stripped answer.

        Args:
            message: the prompt text.
            lowercase: when true (the default), the answer is lowercased so
                that callers can compare it against 'yes'/'no'. Pass False
                when the answer must be preserved as typed (e.g. a PIN).
        """
        # Wait a bit to allow some of the log lines to print before we prompt
        await asyncio.sleep(1)

        session = PromptSession(message)
        response = await session.prompt_async()
        if lowercase:
            return response.lower().strip()
        return response.strip()

    async def update_peer_name(self):
        """Resolve and cache the display name of the peer (at most once)."""
        if self.peer_name is not None:
            # We already asked the peer
            return

        # Try to get the peer's name
        if self.peer:
            peer_name = await get_peer_name(self.peer, self.mode)
            self.peer_name = f'{peer_name or ""} [{self.peer.connection.peer_address}]'
        else:
            self.peer_name = '[?]'

    def _print_banner(self, *lines):
        """Print the given lines framed by two horizontal rules."""
        self.print(self.BANNER_RULE)
        for line in lines:
            self.print(line)
        self.print(self.BANNER_RULE)

    async def _ask_yes_no(self, question):
        """Repeat the question until the answer is 'yes' or 'no'.

        Returns True for 'yes' and False for 'no'.
        """
        while True:
            response = await self.prompt(question)

            if response == 'yes':
                return True

            if response == 'no':
                return False

    async def accept(self):
        """Decide whether to accept a pairing request.

        Returns True without asking unless prompting was enabled, in which
        case the user's yes/no answer is returned.
        """
        if not self.do_prompt:
            # Accept silently
            return True

        await self.update_peer_name()

        # Prompt for acceptance
        self._print_banner(f'### Pairing request from {self.peer_name}')
        return await self._ask_yes_no('>>> Accept? ')

    async def compare_numbers(self, number, digits):
        """Ask the user whether the peer displays `number`.

        The number is shown zero-padded to `digits` digits. Returns the
        user's yes/no answer as a bool.
        """
        await self.update_peer_name()

        # Prompt for a numeric comparison
        self._print_banner(f'### Pairing with {self.peer_name}')
        return await self._ask_yes_no(
            f'>>> Does the other device display {number:0{digits}}? '
        )

    async def get_number(self):
        """Ask the user for a numeric PIN, repeating until an integer is entered."""
        await self.update_peer_name()

        # Prompt for a PIN
        while True:
            self._print_banner(f'### Pairing with {self.peer_name}')
            try:
                return int(await self.prompt('>>> Enter PIN: '))
            except ValueError:
                # Not an integer: show the banner and ask again
                continue

    async def display_number(self, number, digits):
        """Show a PIN code, zero-padded to `digits` digits."""
        await self.update_peer_name()

        # Display a PIN code
        self._print_banner(
            f'### Pairing with {self.peer_name}',
            f'### PIN: {number:0{digits}}',
        )

    async def get_string(self, max_length: int):
        """Ask the user for a PIN string (for legacy pairing in classic).

        The answer is returned exactly as typed (apart from surrounding
        whitespace). Empty or over-long answers are rejected; after more than
        3 rejected answers the method gives up.

        Args:
            max_length: maximum accepted length of the PIN.

        Returns:
            The PIN string, or None when the user ran out of tries.
        """
        await self.update_peer_name()

        # Prompt a PIN (for legacy pairing in classic)
        self._print_banner(f'### Pairing with {self.peer_name}')
        count = 0
        while True:
            # Keep the case as typed: a PIN is not a yes/no answer
            response = await self.prompt(
                f'>>> Enter PIN (1-{max_length} chars):', lowercase=False
            )
            if not response:
                problem = 'no PIN was entered'
            # The limit is checked on the UTF-8 encoded length; this assumes
            # the stack transmits the PIN as UTF-8 bytes -- TODO confirm
            elif len(response.encode('utf-8')) > max_length:
                problem = f'PIN is longer than {max_length} bytes'
            else:
                return response

            count += 1
            if count > 3:
                self.print('too many tries, stopping the pairing')
                return None

            self.print(f'{problem}, try again')
193
194
195# -----------------------------------------------------------------------------
async def get_peer_name(peer, mode):
    """Return the peer's name, or None when it cannot be obtained.

    In 'classic' mode the result of `peer.request_name()` is returned as is.
    Otherwise the name is read over GATT from the device name characteristic
    of the first generic access service found, and decoded as UTF-8.
    """
    if mode == 'classic':
        return await peer.request_name()

    # Any other mode: look the name up through GATT
    gap_services = await peer.discover_service(GATT_GENERIC_ACCESS_SERVICE)
    if gap_services:
        name_values = await peer.read_characteristics_by_uuid(
            GATT_DEVICE_NAME_CHARACTERISTIC, gap_services[0]
        )
        if name_values:
            return name_values[0].decode('utf-8')

    # No service, or no value for the characteristic
    return None
212
213
214# -----------------------------------------------------------------------------
# Whether the one-time "insufficient authentication" error has already been
# raised: index 0 is used by read_with_error, index 1 by write_with_error
AUTHENTICATION_ERROR_RETURNED = [False, False]
216
217
def read_with_error(connection):
    """Read handler that fails until the link is encrypted and authenticated.

    Raises an insufficient-encryption ATT error while the connection is not
    encrypted. On an encrypted connection, the first call raises an
    insufficient-authentication ATT error (and records that it did so); every
    later call returns a single byte of value 1.
    """
    if connection.is_encrypted:
        if not AUTHENTICATION_ERROR_RETURNED[0]:
            # First encrypted read: fail once, and remember that we did
            AUTHENTICATION_ERROR_RETURNED[0] = True
            raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR)

        return bytes([1])

    raise ATT_Error(ATT_INSUFFICIENT_ENCRYPTION_ERROR)
227
228
def write_with_error(connection, _value):
    """Write handler that fails until the link is encrypted and authenticated.

    Raises an insufficient-encryption ATT error while the connection is not
    encrypted. On an encrypted connection, the first call raises an
    insufficient-authentication ATT error (and records that it did so); later
    calls succeed. The written value is ignored.
    """
    if not connection.is_encrypted:
        raise ATT_Error(ATT_INSUFFICIENT_ENCRYPTION_ERROR)

    if AUTHENTICATION_ERROR_RETURNED[1]:
        # The one-time error was already returned: accept the write
        return

    AUTHENTICATION_ERROR_RETURNED[1] = True
    raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR)
236
237
238# -----------------------------------------------------------------------------
def on_connection(connection, request):
    """Set up a new connection: subscribe to its events, optionally request pairing.

    Args:
        connection: the connection that was just established.
        request: when true, `connection.request_pairing()` is called after
            the event handlers are registered.
    """
    print(color(f'<<< Connection: {connection}', 'green'))

    # Adapters that bind this connection to the module-level handlers
    def handle_pairing(keys):
        return on_pairing(connection, keys)

    def handle_pairing_failure(reason):
        return on_pairing_failure(connection, reason)

    def handle_encryption_change():
        return on_connection_encryption_change(connection)

    # Listen for pairing events, then for encryption changes
    subscriptions = (
        ('pairing_start', on_pairing_start),
        ('pairing', handle_pairing),
        ('pairing_failure', handle_pairing_failure),
        ('connection_encryption_change', handle_encryption_change),
    )
    for event_name, handler in subscriptions:
        connection.on(event_name, handler)

    # Request pairing if needed
    if not request:
        return

    print(color('>>> Requesting pairing', 'green'))
    connection.request_pairing()
259
260
261# -----------------------------------------------------------------------------
def on_connection_encryption_change(connection):
    """Print whether the connection is now encrypted, framed in blue."""
    # Build the whole phrase at once so that the negative form keeps its
    # space ("not encrypted" rather than "notencrypted")
    state = 'encrypted' if connection.is_encrypted else 'not encrypted'

    print(color('@@@-----------------------------------', 'blue'))
    print(color(f'@@@ Connection is {state}', 'blue'))
    print(color('@@@-----------------------------------', 'blue'))
271
272
273# -----------------------------------------------------------------------------
def on_pairing_start():
    """Print a magenta banner announcing that pairing is starting."""
    rule = '***-----------------------------------'
    for line in (rule, '*** Pairing starting', rule):
        print(color(line, 'magenta'))
278
279
280# -----------------------------------------------------------------------------
@AsyncRunner.run_in_task()
async def on_pairing(connection, keys):
    """Report a successful pairing, then disconnect and signal the waiter.

    Prints the peer address and the keys in cyan, waits POST_PAIRING_DELAY,
    disconnects, and finally calls `terminate()` on the shared waiter.
    """

    def cyan(text):
        return color(text, 'cyan')

    print(cyan('***-----------------------------------'))
    print(cyan(f'*** Paired! (peer identity={connection.peer_address})'))
    keys.print(prefix=cyan('*** '))
    print(cyan('***-----------------------------------'))

    # Pause before tearing the link down
    await asyncio.sleep(POST_PAIRING_DELAY)
    await connection.disconnect()

    Waiter.instance.terminate()
290
291
292# -----------------------------------------------------------------------------
@AsyncRunner.run_in_task()
async def on_pairing_failure(connection, reason):
    """Report a failed pairing, then disconnect and signal the waiter.

    The reason is rendered with the SMP error-name helper and printed in red.
    """
    rule = '***-----------------------------------'
    for line in (rule, f'*** Pairing failed: {smp_error_name(reason)}', rule):
        print(color(line, 'red'))

    await connection.disconnect()

    Waiter.instance.terminate()
300
301
302# -----------------------------------------------------------------------------
async def pair(
    mode,
    sc,
    mitm,
    bond,
    ctkd,
    identity_address,
    linger,
    io,
    oob,
    prompt,
    request,
    print_keys,
    keystore_file,
    device_config,
    hci_transport,
    address_or_name,
):
    """Open the HCI transport, configure a device and run a pairing session.

    The coroutine returns once the shared `Waiter` has been terminated.

    Args:
        mode: 'le' or 'classic'; selects the transport, the way the device is
            made reachable, and the pairing call used when connecting out.
        sc: passed to `PairingConfig` as `sc`.
        mitm: passed to `PairingConfig` as `mitm`.
        bond: passed to `PairingConfig` as `bonding`.
        ctkd: in classic mode, assigned to `device.classic_smp_enabled`.
        identity_address: 'public' or 'random' to select an identity address
            type; any other value leaves the type as None.
        linger: passed to the `Waiter`.
        io: IO capability string passed to the `Delegate`.
        oob: hex-encoded OOB advertising data from the peer, or '-' to enable
            OOB without peer data; a falsy value disables OOB.
        prompt: passed to the `Delegate` as its `do_prompt` flag.
        request: passed to `on_connection` for each new connection; when
            false and `address_or_name` is set, pairing (LE) or
            authentication (classic) is started from this side instead.
        print_keys: when true, print the keystore content before pairing.
        keystore_file: optional file name for a JSON keystore.
        device_config: device configuration file.
        hci_transport: transport specification for the HCI connection.
        address_or_name: peer to connect to; when None, wait for a peer to
            connect instead.
    """
    # Created here because the waiter needs a running event loop
    Waiter.instance = Waiter(linger=linger)

    print('<<< connecting to HCI...')
    async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
        print('<<< connected')

        # Create a device to manage the host
        device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)

        # Expose a GATT characteristic that can be used to trigger pairing by
        # responding with an authentication error when read or written
        if mode == 'le':
            device.le_enabled = True
            device.add_service(
                Service(
                    '50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
                    [
                        Characteristic(
                            '552957FB-CF1F-4A31-9535-E78847E1A714',
                            Characteristic.Properties.READ
                            | Characteristic.Properties.WRITE,
                            Characteristic.READABLE | Characteristic.WRITEABLE,
                            CharacteristicValue(
                                read=read_with_error, write=write_with_error
                            ),
                        )
                    ],
                )
            )

        # Enable Classic support (and optionally SMP over Classic) in classic mode
        if mode == 'classic':
            device.classic_enabled = True
            device.classic_smp_enabled = ctkd

        # Get things going
        await device.power_on()

        # Set a custom keystore if specified on the command line
        if keystore_file:
            device.keystore = JsonKeyStore.from_device(device, filename=keystore_file)

        # Print the existing keys before pairing
        if print_keys and device.keystore:
            print(color('@@@-----------------------------------', 'blue'))
            print(color('@@@ Pairing Keys:', 'blue'))
            await device.keystore.print(prefix=color('@@@ ', 'blue'))
            print(color('@@@-----------------------------------', 'blue'))

        # Create an OOB context if needed
        if oob:
            our_oob_context = OobContext()
            # '-' means OOB is enabled but the peer gave us no data
            shared_data = (
                None
                if oob == '-'
                else OobData.from_ad(AdvertisingData.from_bytes(bytes.fromhex(oob)))
            )
            legacy_context = OobLegacyContext()
            oob_contexts = PairingConfig.OobConfig(
                our_context=our_oob_context,
                peer_data=shared_data,
                legacy_context=legacy_context,
            )
            # NOTE(review): `shared_data` is whatever `OobData.from_ad` returned
            # for the peer, yet it is used both as `peer_data` above and as the
            # `shared_data` of the OOB data printed below -- confirm that these
            # are the intended values/types
            oob_data = OobData(
                address=device.random_address,
                shared_data=shared_data,
                legacy_context=legacy_context,
            )
            # Print the OOB values so they can be handed to the peer
            print(color('@@@-----------------------------------', 'yellow'))
            print(color('@@@ OOB Data:', 'yellow'))
            print(color(f'@@@   {our_oob_context.share()}', 'yellow'))
            print(color(f'@@@   TK={legacy_context.tk.hex()}', 'yellow'))
            print(color(f'@@@   HEX: ({bytes(oob_data.to_ad()).hex()})', 'yellow'))
            print(color('@@@-----------------------------------', 'yellow'))
        else:
            oob_contexts = None

        # Set up a pairing config factory
        if identity_address == 'public':
            identity_address_type = PairingConfig.AddressType.PUBLIC
        elif identity_address == 'random':
            identity_address_type = PairingConfig.AddressType.RANDOM
        else:
            identity_address_type = None
        # A new Delegate is created for each connection the factory is called with
        device.pairing_config_factory = lambda connection: PairingConfig(
            sc=sc,
            mitm=mitm,
            bonding=bond,
            oob=oob_contexts,
            identity_address_type=identity_address_type,
            delegate=Delegate(mode, connection, io, prompt),
        )

        # Connect to a peer or wait for a connection
        device.on('connection', lambda connection: on_connection(connection, request))
        if address_or_name is not None:
            print(color(f'=== Connecting to {address_or_name}...', 'green'))
            connection = await device.connect(
                address_or_name,
                transport=BT_LE_TRANSPORT if mode == 'le' else BT_BR_EDR_TRANSPORT,
            )

            # When `request` is set, on_connection has already asked the peer
            # to initiate pairing, so we only start it ourselves otherwise
            if not request:
                try:
                    if mode == 'le':
                        await connection.pair()
                    else:
                        await connection.authenticate()
                except ProtocolError as error:
                    print(color(f'Pairing failed: {error}', 'red'))

        else:
            if mode == 'le':
                # Advertise so that peers can find us and connect
                await device.start_advertising(auto_restart=True)
            else:
                # Become discoverable and connectable
                await device.set_discoverable(True)
                await device.set_connectable(True)

        # Run until the waiter is terminated (by the pairing event handlers)
        await Waiter.instance.wait_until_terminated()
443
444
445# -----------------------------------------------------------------------------
class LogHandler(logging.Handler):
    """Logging handler that prints each record as `LEVEL:logger:message`."""

    def __init__(self):
        super().__init__()
        formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')
        self.setFormatter(formatter)

    def emit(self, record):
        """Format the record and write it with `print`."""
        print(self.format(record))
454
455
456# -----------------------------------------------------------------------------
# Command-line entry point: parses the options, sets up logging and runs
# `pair()` to completion.
# NOTE: deliberately documented with comments rather than a docstring, since a
# docstring on a click command becomes part of its `--help` output.
@click.command()
@click.option(
    '--mode',
    type=click.Choice(['le', 'classic']),
    default='le',
    show_default=True,
)
@click.option(
    '--sc',
    type=bool,
    default=True,
    help='Use the Secure Connections protocol',
    show_default=True,
)
@click.option(
    '--mitm',
    type=bool,
    default=True,
    help='Request MITM protection',
    show_default=True,
)
@click.option(
    '--bond',
    type=bool,
    default=True,
    help='Enable bonding',
    show_default=True,
)
@click.option(
    '--ctkd',
    type=bool,
    default=True,
    help='Enable CTKD',
    show_default=True,
)
@click.option(
    '--identity-address',
    type=click.Choice(['random', 'public']),
)
@click.option('--linger', default=False, is_flag=True, help='Linger after pairing')
@click.option(
    '--io',
    type=click.Choice(
        ['keyboard', 'display', 'display+keyboard', 'display+yes/no', 'none']
    ),
    default='display+keyboard',
    show_default=True,
)
@click.option(
    '--oob',
    metavar='<oob-data-hex>',
    help=(
        'Use OOB pairing with this data from the peer '
        '(use "-" to enable OOB without peer data)'
    ),
)
@click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request')
@click.option(
    '--request', is_flag=True, help='Request that the connecting peer initiate pairing'
)
@click.option('--print-keys', is_flag=True, help='Print the bond keys before pairing')
@click.option(
    '--keystore-file',
    metavar='<filename>',
    help='File in which to store the pairing keys',
)
@click.argument('device-config')
@click.argument('hci_transport')
@click.argument('address-or-name', required=False)
def main(
    mode,
    sc,
    mitm,
    bond,
    ctkd,
    identity_address,
    linger,
    io,
    oob,
    prompt,
    request,
    print_keys,
    keystore_file,
    device_config,
    hci_transport,
    address_or_name,
):
    # Setup logging: send everything through our handler, at the level named
    # by BUMBLE_LOGLEVEL (INFO when the variable is not set)
    logging.getLogger().addHandler(LogHandler())
    level_name = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()
    logging.getLogger().setLevel(level_name)

    # Pair
    asyncio.run(
        pair(
            mode=mode,
            sc=sc,
            mitm=mitm,
            bond=bond,
            ctkd=ctkd,
            identity_address=identity_address,
            linger=linger,
            io=io,
            oob=oob,
            prompt=prompt,
            request=request,
            print_keys=print_keys,
            keystore_file=keystore_file,
            device_config=device_config,
            hci_transport=hci_transport,
            address_or_name=address_or_name,
        )
    )
560
561
562# -----------------------------------------------------------------------------
# Run the command-line tool only when this file is executed as a script
if __name__ == '__main__':
    main()
565