1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import logging
19import asyncio
20from functools import partial
21
22from bumble.core import (
23    BT_PERIPHERAL_ROLE,
24    BT_BR_EDR_TRANSPORT,
25    BT_LE_TRANSPORT,
26    InvalidStateError,
27)
28from bumble.colors import color
29from bumble.hci import (
30    Address,
31    HCI_SUCCESS,
32    HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
33    HCI_CONNECTION_TIMEOUT_ERROR,
34    HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
35    HCI_PAGE_TIMEOUT_ERROR,
36    HCI_Connection_Complete_Event,
37)
38from bumble import controller
39
40from typing import Optional, Set
41
42# -----------------------------------------------------------------------------
43# Logging
44# -----------------------------------------------------------------------------
45logger = logging.getLogger(__name__)
46
47
48# -----------------------------------------------------------------------------
49# Utils
50# -----------------------------------------------------------------------------
def parse_parameters(params_str):
    '''
    Parse a comma-separated list of ``key=value`` pairs into a dict of strings.

    Entries without an ``=`` are ignored. Only the first ``=`` of an entry
    separates the key from the value, so values may themselves contain ``=``.
    A ``None`` or empty input yields an empty dict.

    Example: ``parse_parameters('reason=19,x=1')`` -> ``{'reason': '19', 'x': '1'}``
    '''
    result = {}
    if not params_str:
        return result
    for param_str in params_str.split(','):
        key, separator, value = param_str.partition('=')
        if separator:
            result[key] = value
    return result
58
59
60# -----------------------------------------------------------------------------
61# TODO: add more support for various LL exchanges
62# (see Vol 6, Part B - 2.4 DATA CHANNEL PDU)
63# -----------------------------------------------------------------------------
class LocalLink:
    '''
    Link bus for controllers to communicate with each other.

    Every controller attached to the link is held as a Python object in
    ``self.controllers``; traffic between controllers is simulated by calling
    the destination controller's ``on_link_*`` / ``on_classic_*`` methods
    directly. Some deliveries happen synchronously, others are deferred to the
    next event-loop iteration with ``loop.call_soon``.

    LE peers are looked up by ``random_address``, Classic (BR/EDR) peers by
    ``public_address``.
    '''

    controllers: Set[controller.Controller]

    def __init__(self):
        self.controllers = set()
        # (central_address, le_create_connection_command) of the LE connection
        # being set up, or None. Consumed by on_connection_complete().
        self.pending_connection = None
        # (initiator_controller, responder_controller) of the Classic ACL
        # connection waiting to be accepted, or None.
        self.pending_classic_connection = None

    ############################################################
    # Common utils
    ############################################################

    def add_controller(self, controller):
        '''Attach `controller` to the link.'''
        logger.debug(f'new controller: {controller}')
        self.controllers.add(controller)

    def remove_controller(self, controller):
        '''Detach `controller` from the link (KeyError if it is not attached).'''
        self.controllers.remove(controller)

    def find_controller(self, address):
        '''Return a controller whose LE random address equals `address`, or None.'''
        for candidate in self.controllers:
            if candidate.random_address == address:
                return candidate
        return None

    def find_classic_controller(
        self, address: Address
    ) -> Optional[controller.Controller]:
        '''Return a controller whose public (BR/EDR) address equals `address`, or None.'''
        for candidate in self.controllers:
            if candidate.public_address == address:
                return candidate
        return None

    def get_pending_connection(self):
        '''Return the pending LE connection tuple, or None.'''
        return self.pending_connection

    ############################################################
    # LE handlers
    ############################################################

    def on_address_changed(self, controller):
        '''Nothing to do: peers are looked up by address on every delivery.'''

    def send_advertising_data(self, sender_address, data):
        '''Deliver advertising data to every controller except the sender.'''
        for candidate in self.controllers:
            if candidate.random_address != sender_address:
                candidate.on_link_advertising_data(sender_address, data)

    def send_acl_data(self, sender_controller, destination_address, transport, data):
        '''
        Deliver ACL data to the controller that owns `destination_address`.

        LE traffic is routed by random address, BR/EDR traffic by public address.
        Data for an unknown destination is dropped silently; data for any other
        transport is dropped with a warning.
        '''
        if transport == BT_LE_TRANSPORT:
            destination_controller = self.find_controller(destination_address)
            source_address = sender_controller.random_address
        elif transport == BT_BR_EDR_TRANSPORT:
            destination_controller = self.find_classic_controller(destination_address)
            source_address = sender_controller.public_address
        else:
            # Neither LE nor BR/EDR: there is no address space to route in
            logger.warning(f'!!! unsupported transport for ACL data: {transport}')
            return

        if destination_controller is not None:
            destination_controller.on_link_acl_data(source_address, transport, data)

    def on_connection_complete(self):
        '''
        Finish the LE connection recorded by connect().

        The central is told HCI_SUCCESS and the peripheral is notified when a
        controller with the requested peer address exists; otherwise the central
        gets HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR.
        '''
        # Check that we expect this call
        if not self.pending_connection:
            logger.warning('on_connection_complete with no pending connection')
            return

        central_address, le_create_connection_command = self.pending_connection
        self.pending_connection = None

        # Find the controller that initiated the connection
        if not (central_controller := self.find_controller(central_address)):
            logger.warning('!!! Initiating controller not found')
            return

        # Connect to the first controller with a matching address
        if peripheral_controller := self.find_controller(
            le_create_connection_command.peer_address
        ):
            central_controller.on_link_peripheral_connection_complete(
                le_create_connection_command, HCI_SUCCESS
            )
            peripheral_controller.on_link_central_connected(central_address)
            return

        # No peripheral found
        central_controller.on_link_peripheral_connection_complete(
            le_create_connection_command, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
        )

    def connect(self, central_address, le_create_connection_command):
        '''Record an LE connection request and complete it on the next loop iteration.'''
        logger.debug(
            f'$$$ CONNECTION {central_address} -> '
            f'{le_create_connection_command.peer_address}'
        )
        self.pending_connection = (central_address, le_create_connection_command)
        asyncio.get_running_loop().call_soon(self.on_connection_complete)

    def on_disconnection_complete(
        self, central_address, peripheral_address, disconnect_command
    ):
        '''Notify both ends of an LE disconnection requested via disconnect().'''
        # Find the controller that initiated the disconnection
        if not (central_controller := self.find_controller(central_address)):
            logger.warning('!!! Initiating controller not found')
            return

        # Disconnect from the first controller with a matching address
        if peripheral_controller := self.find_controller(peripheral_address):
            peripheral_controller.on_link_central_disconnected(
                central_address, disconnect_command.reason
            )

        central_controller.on_link_peripheral_disconnection_complete(
            disconnect_command, HCI_SUCCESS
        )

    def disconnect(self, central_address, peripheral_address, disconnect_command):
        '''Schedule on_disconnection_complete() for the next loop iteration.'''
        logger.debug(
            f'$$$ DISCONNECTION {central_address} -> '
            f'{peripheral_address}: reason = {disconnect_command.reason}'
        )
        asyncio.get_running_loop().call_soon(
            self.on_disconnection_complete,
            central_address,
            peripheral_address,
            disconnect_command,
        )

    # pylint: disable=too-many-arguments
    def on_connection_encrypted(
        self, central_address, peripheral_address, rand, ediv, ltk
    ):
        '''Tell both ends (when found) that the link is encrypted with these keys.'''
        logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}')

        if central_controller := self.find_controller(central_address):
            central_controller.on_link_encrypted(peripheral_address, rand, ediv, ltk)

        if peripheral_controller := self.find_controller(peripheral_address):
            peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)

    def create_cis(
        self,
        central_controller: controller.Controller,
        peripheral_address: Address,
        cig_id: int,
        cis_id: int,
    ) -> None:
        '''Forward a CIS request to the peripheral on the next loop iteration.'''
        logger.debug(
            f'$$$ CIS Request {central_controller.random_address} -> {peripheral_address}'
        )
        if peripheral_controller := self.find_controller(peripheral_address):
            asyncio.get_running_loop().call_soon(
                peripheral_controller.on_link_cis_request,
                central_controller.random_address,
                cig_id,
                cis_id,
            )

    def accept_cis(
        self,
        peripheral_controller: controller.Controller,
        central_address: Address,
        cig_id: int,
        cis_id: int,
    ) -> None:
        '''Schedule "CIS established" on both ends once the peripheral accepts.'''
        logger.debug(
            f'$$$ CIS Accept {peripheral_controller.random_address} -> {central_address}'
        )
        if central_controller := self.find_controller(central_address):
            loop = asyncio.get_running_loop()
            loop.call_soon(central_controller.on_link_cis_established, cig_id, cis_id)
            loop.call_soon(
                peripheral_controller.on_link_cis_established, cig_id, cis_id
            )

    def disconnect_cis(
        self,
        initiator_controller: controller.Controller,
        peer_address: Address,
        cig_id: int,
        cis_id: int,
    ) -> None:
        '''Schedule "CIS disconnected" on both ends when the peer is found.'''
        logger.debug(
            f'$$$ CIS Disconnect {initiator_controller.random_address} -> {peer_address}'
        )
        if peer_controller := self.find_controller(peer_address):
            loop = asyncio.get_running_loop()
            loop.call_soon(
                initiator_controller.on_link_cis_disconnected, cig_id, cis_id
            )
            loop.call_soon(peer_controller.on_link_cis_disconnected, cig_id, cis_id)

    ############################################################
    # Classic handlers
    ############################################################

    def classic_connect(self, initiator_controller, responder_address):
        '''
        Start a Classic ACL connection.

        Fails immediately with HCI_PAGE_TIMEOUT_ERROR when no controller has
        `responder_address`; otherwise raises a connection request on the
        responder and records the pair in `pending_classic_connection`.
        '''
        logger.debug(
            f'[Classic] {initiator_controller.public_address} connects to {responder_address}'
        )
        responder_controller = self.find_classic_controller(responder_address)
        if responder_controller is None:
            initiator_controller.on_classic_connection_complete(
                responder_address, HCI_PAGE_TIMEOUT_ERROR
            )
            return
        self.pending_classic_connection = (initiator_controller, responder_controller)

        responder_controller.on_classic_connection_request(
            initiator_controller.public_address,
            HCI_Connection_Complete_Event.ACL_LINK_TYPE,
        )

    def classic_accept_connection(
        self, responder_controller, initiator_address, responder_role
    ):
        '''
        Complete a Classic ACL connection accepted by `responder_controller`.

        The responder is notified synchronously; the initiator is notified on
        the next loop iteration, including a role change when the responder did
        not take the peripheral role.
        '''
        logger.debug(
            f'[Classic] {responder_controller.public_address} accepts to connect {initiator_address}'
        )
        initiator_controller = self.find_classic_controller(initiator_address)
        if initiator_controller is None:
            # The completion event describes the peer, i.e. the initiator
            responder_controller.on_classic_connection_complete(
                initiator_address, HCI_PAGE_TIMEOUT_ERROR
            )
            self.pending_classic_connection = None
            return

        def notify_initiator():
            if responder_role != BT_PERIPHERAL_ROLE:
                initiator_controller.on_classic_role_change(
                    responder_controller.public_address, int(not responder_role)
                )
            initiator_controller.on_classic_connection_complete(
                responder_controller.public_address, HCI_SUCCESS
            )

        # call_soon (rather than an un-referenced create_task) cannot be
        # garbage-collected before it runs
        asyncio.get_running_loop().call_soon(notify_initiator)
        responder_controller.on_classic_role_change(
            initiator_controller.public_address, responder_role
        )
        responder_controller.on_classic_connection_complete(
            initiator_controller.public_address, HCI_SUCCESS
        )
        self.pending_classic_connection = None

    def classic_disconnect(self, initiator_controller, responder_address, reason):
        '''
        Disconnect a Classic ACL link.

        The initiator is notified on the next loop iteration; the responder is
        notified synchronously when it is still attached to the link.
        '''
        logger.debug(
            f'[Classic] {initiator_controller.public_address} disconnects {responder_address}'
        )
        responder_controller = self.find_classic_controller(responder_address)

        asyncio.get_running_loop().call_soon(
            initiator_controller.on_classic_disconnected, responder_address, reason
        )
        if responder_controller is None:
            logger.warning('!!! Responder controller not found')
            return
        responder_controller.on_classic_disconnected(
            initiator_controller.public_address, reason
        )

    def classic_switch_role(
        self, initiator_controller, responder_address, initiator_new_role
    ):
        '''Apply a Classic role switch to both ends (no-op if the peer is gone).'''
        responder_controller = self.find_classic_controller(responder_address)
        if responder_controller is None:
            return

        asyncio.get_running_loop().call_soon(
            initiator_controller.on_classic_role_change,
            responder_address,
            initiator_new_role,
        )
        # The responder takes the opposite role
        responder_controller.on_classic_role_change(
            initiator_controller.public_address, int(not initiator_new_role)
        )

    def classic_sco_connect(
        self,
        initiator_controller: controller.Controller,
        responder_address: Address,
        link_type: int,
    ):
        '''Raise an SCO/eSCO connection request on the responder.'''
        logger.debug(
            f'[Classic] {initiator_controller.public_address} connects SCO to {responder_address}'
        )
        responder_controller = self.find_classic_controller(responder_address)
        # Initiator controller should handle it.
        assert responder_controller

        responder_controller.on_classic_connection_request(
            initiator_controller.public_address,
            link_type,
        )

    def classic_accept_sco_connection(
        self,
        responder_controller: controller.Controller,
        initiator_address: Address,
        link_type: int,
    ):
        '''
        Complete an SCO/eSCO connection accepted by `responder_controller`.

        The responder is notified synchronously and the initiator on the next
        loop iteration.
        '''
        logger.debug(
            f'[Classic] {responder_controller.public_address} accepts to connect SCO {initiator_address}'
        )
        initiator_controller = self.find_classic_controller(initiator_address)
        if initiator_controller is None:
            # The completion event describes the peer, i.e. the initiator
            responder_controller.on_classic_sco_connection_complete(
                initiator_address,
                HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
                link_type,
            )
            return

        asyncio.get_running_loop().call_soon(
            initiator_controller.on_classic_sco_connection_complete,
            responder_controller.public_address,
            HCI_SUCCESS,
            link_type,
        )
        responder_controller.on_classic_sco_connection_complete(
            initiator_controller.public_address, HCI_SUCCESS, link_type
        )
386
387
388# -----------------------------------------------------------------------------
class RemoteLink:
    '''
    A Link implementation that communicates with other virtual controllers via a
    WebSocket relay.

    Only one controller can be attached. Outgoing traffic is queued with
    execute() and sent, one item at a time, by run_executor_loop(). Incoming
    relay messages ``<keyword>[:<payload>]`` are dispatched to
    ``on_<keyword>_received``; targeted messages from peers are further
    dispatched to ``on_<keyword>_message_received``.
    '''

    def __init__(self, uri):
        self.controller = None
        self.uri = uri
        self.execution_queue = asyncio.Queue()
        # Resolved with the websocket once connected (or failed if connecting fails)
        self.websocket = asyncio.get_running_loop().create_future()
        # Future for the in-flight RPC command result, or None
        self.rpc_result = None
        # LE Create Connection command awaiting the peer's 'connected' ack, or None
        self.pending_connection = None
        # This class implements no Classic connection logic; the attribute is
        # initialised so get_pending_classic_connection() returns None.
        self.pending_classic_connection = None
        # Relay addresses (str) that we have connected to, as central
        self.central_connections = set()
        # Relay addresses (str) that have connected to us, as peripheral
        self.peripheral_connections = set()

        # Connect and run asynchronously. Keep references to the tasks: the event
        # loop only keeps weak references, so un-referenced tasks may be
        # garbage-collected before they finish.
        self._connection_task = asyncio.create_task(self.run_connection())
        self._executor_task = asyncio.create_task(self.run_executor_loop())

    def add_controller(self, controller):
        '''Attach the (single) controller; InvalidStateError if one is attached.'''
        if self.controller:
            raise InvalidStateError('controller already set')
        self.controller = controller

    def remove_controller(self, controller):
        '''Detach `controller`; InvalidStateError if it is not the attached one.'''
        if self.controller != controller:
            raise InvalidStateError('controller mismatch')
        self.controller = None

    def get_pending_connection(self):
        '''Return the LE Create Connection command awaiting an ack, or None.'''
        return self.pending_connection

    def get_pending_classic_connection(self):
        '''Return the pending Classic connection (always None for this link).'''
        return self.pending_classic_connection

    async def wait_until_connected(self):
        '''Wait until the relay websocket is connected (raises if connecting failed).'''
        await self.websocket

    def execute(self, async_function):
        '''Queue the coroutine produced by `async_function()` for the executor loop.'''
        self.execution_queue.put_nowait(async_function())

    async def run_executor_loop(self):
        '''Await queued coroutines one by one, logging (not propagating) failures.'''
        logger.debug('executor loop starting')
        while True:
            item = await self.execution_queue.get()
            try:
                await item
            except Exception as error:
                logger.warning(
                    f'{color("!!! Exception in async handler:", "red")} {error}'
                )

    async def run_connection(self):
        '''
        Connect to the relay, then read and dispatch relay messages until the
        websocket fails.

        A connection failure is stored on ``self.websocket`` so that awaiting
        callers get the error instead of waiting forever. An exception raised by
        one message handler is logged and does not stop the receive loop.
        '''
        # Connect to the relay
        logger.debug(f'connecting to {self.uri}')
        try:
            import websockets  # lazy import

            # pylint: disable-next=no-member
            websocket = await websockets.connect(self.uri)
        except Exception as error:
            self.websocket.set_exception(error)
            raise
        self.websocket.set_result(websocket)
        logger.debug(f'connected to {self.uri}')

        while True:
            message = await websocket.recv()
            logger.debug(f'received message: {message}')
            keyword, *payload = message.split(':', 1)

            handler_name = f'on_{keyword}_received'
            handler = getattr(self, handler_name, None)
            if handler:
                try:
                    await handler(payload[0] if payload else None)
                except Exception:
                    logger.exception(f'exception in {handler_name}')

    def close(self):
        '''Close the websocket (asynchronously) if the connection was established.'''
        if self.websocket.done():
            logger.debug('closing websocket')
            websocket = self.websocket.result()
            asyncio.create_task(websocket.close())

    async def on_result_received(self, result):
        '''Complete the in-flight RPC, if any, with the relay's result.'''
        if self.rpc_result is not None and not self.rpc_result.done():
            self.rpc_result.set_result(result)

    async def on_left_received(self, address):
        '''Tear down any connection with a peer that left the relay.'''
        if address in self.central_connections:
            self.controller.on_link_peripheral_disconnected(Address(address))
            self.central_connections.remove(address)

        if address in self.peripheral_connections:
            self.controller.on_link_central_disconnected(
                Address(address), HCI_CONNECTION_TIMEOUT_ERROR
            )
            self.peripheral_connections.remove(address)

    async def on_unreachable_received(self, target):
        '''Treat an unreachable target like a peer that left.'''
        await self.on_left_received(target)

    async def on_message_received(self, message):
        '''Dispatch ``<sender>/<keyword>[:<payload>]`` to a per-keyword handler.'''
        if message is None:
            return
        sender, *payload = message.split('/', 1)
        if payload:
            keyword, *payload = payload[0].split(':', 1)
            handler_name = f'on_{keyword}_message_received'
            handler = getattr(self, handler_name, None)
            if handler:
                await handler(sender, payload[0] if payload else None)

    async def on_advertisement_message_received(self, sender, advertisement):
        '''Pass hex-encoded advertising data from `sender` to the controller.'''
        try:
            self.controller.on_link_advertising_data(
                Address(sender), bytes.fromhex(advertisement)
            )
        except Exception:
            logger.exception('exception')

    async def on_acl_message_received(self, sender, acl_data):
        '''Pass hex-encoded ACL data from `sender` to the controller.'''
        try:
            # The relay only carries LE traffic (see send_acl_data)
            self.controller.on_link_acl_data(
                Address(sender), BT_LE_TRANSPORT, bytes.fromhex(acl_data)
            )
        except Exception:
            logger.exception('exception')

    async def on_connect_message_received(self, sender, _):
        '''Accept an incoming connection request: we become the peripheral.'''
        # Remember the connection
        self.peripheral_connections.add(sender)

        # Notify the controller
        logger.debug(f'connection from central {sender}')
        self.controller.on_link_central_connected(Address(sender))

        # Accept the connection by responding to it
        await self.send_targeted_message(sender, 'connected')

    async def on_connected_message_received(self, sender, _):
        '''Handle the peer's ack of our 'connect' request: we are the central.'''
        if not self.pending_connection:
            logger.warning('received a connection ack, but no connection is pending')
            return

        # Consume the pending request, otherwise every later connect() would be
        # rejected as 'connection already pending'
        le_create_connection_command = self.pending_connection
        self.pending_connection = None

        # Remember the connection
        self.central_connections.add(sender)

        # Notify the controller
        logger.debug(
            f'connected to peripheral {le_create_connection_command.peer_address}'
        )
        self.controller.on_link_peripheral_connection_complete(
            le_create_connection_command, HCI_SUCCESS
        )

    async def on_disconnect_message_received(self, sender, message):
        '''Handle a peer-initiated disconnection (``reason=<int>`` is optional).'''
        # Notify the controller
        params = parse_parameters(message) if message else {}
        reason = int(params.get('reason', str(HCI_CONNECTION_TIMEOUT_ERROR)))
        self.controller.on_link_central_disconnected(Address(sender), reason)

        # Forget the connection
        self.peripheral_connections.discard(sender)

    async def on_encrypted_message_received(self, sender, _):
        '''Mark the link with `sender` as encrypted (placeholder key material).'''
        # TODO parse params to get real args
        self.controller.on_link_encrypted(Address(sender), bytes(8), 0, bytes(16))

    async def send_rpc_command(self, command):
        '''
        Send an RPC command to the relay and wait for its 'result' reply.

        Only one RPC may be in flight at a time; callers presumably go through
        the serialized executor loop (see notify_address_changed).
        '''
        # Ensure we have a connection
        websocket = await self.websocket

        # Create a future value to hold the eventual result
        assert self.rpc_result is None
        self.rpc_result = asyncio.get_running_loop().create_future()
        try:
            # Send the command, then wait for the result
            await websocket.send(command)
            rpc_result = await self.rpc_result
        finally:
            # Always reset, so one failed send does not break every later RPC
            self.rpc_result = None
        logger.debug(f'rpc_result: {rpc_result}')

        # TODO: parse the result

    async def send_targeted_message(self, target, message):
        '''Send `message` to relay peer `target` ('*' addresses every peer).'''
        # Ensure we have a connection
        websocket = await self.websocket

        # Send the message
        await websocket.send(f'@{target} {message}')

    async def notify_address_changed(self):
        '''Tell the relay the controller's current random address.'''
        await self.send_rpc_command(f'/set-address {self.controller.random_address}')

    def on_address_changed(self, controller):
        '''Queue a relay address update after the controller's address changed.'''
        logger.info(f'address changed for {controller}: {controller.random_address}')

        # Notify the relay of the change
        self.execute(self.notify_address_changed)

    async def send_advertising_data_to_relay(self, data):
        '''Broadcast advertising data, hex-encoded, to every relay peer.'''
        await self.send_targeted_message('*', f'advertisement:{data.hex()}')

    def send_advertising_data(self, _, data):
        '''Queue a broadcast of advertising data.'''
        self.execute(partial(self.send_advertising_data_to_relay, data))

    async def send_acl_data_to_relay(self, peer_address, data):
        '''Send ACL data, hex-encoded, to one relay peer.'''
        await self.send_targeted_message(peer_address, f'acl:{data.hex()}')

    def send_acl_data(self, _, peer_address, _transport, data):
        '''Queue ACL data for `peer_address` (the transport is currently ignored).'''
        # TODO: handle different transport
        self.execute(partial(self.send_acl_data_to_relay, peer_address, data))

    async def send_connection_request_to_relay(self, peer_address):
        '''Ask relay peer `peer_address` to accept a connection.'''
        await self.send_targeted_message(peer_address, 'connect')

    def connect(self, _, le_create_connection_command):
        '''Start an LE connection; only one may be pending at a time.'''
        if self.pending_connection:
            logger.warning('connection already pending')
            return
        self.pending_connection = le_create_connection_command
        self.execute(
            partial(
                self.send_connection_request_to_relay,
                str(le_create_connection_command.peer_address),
            )
        )

    def on_disconnection_complete(self, disconnect_command):
        '''Report a locally initiated disconnection as successful.'''
        self.controller.on_link_peripheral_disconnection_complete(
            disconnect_command, HCI_SUCCESS
        )

    def disconnect(self, central_address, peripheral_address, disconnect_command):
        '''Tell the peer we disconnect, then report completion on the next loop iteration.'''
        logger.debug(
            f'disconnect {central_address} -> '
            f'{peripheral_address}: reason = {disconnect_command.reason}'
        )
        self.execute(
            partial(
                self.send_targeted_message,
                peripheral_address,
                f'disconnect:reason={disconnect_command.reason}',
            )
        )
        asyncio.get_running_loop().call_soon(
            self.on_disconnection_complete, disconnect_command
        )

    def on_connection_encrypted(self, _, peripheral_address, rand, ediv, ltk):
        '''Mark the link encrypted locally and tell the peer.'''
        asyncio.get_running_loop().call_soon(
            self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk
        )
        self.execute(
            partial(
                self.send_targeted_message,
                peripheral_address,
                f'encrypted:ltk={ltk.hex()}',
            )
        )
644