1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19import asyncio
20import collections
21import dataclasses
22import logging
23import struct
24
25from typing import (
26    Any,
27    Awaitable,
28    Callable,
29    Deque,
30    Dict,
31    Optional,
32    Set,
33    cast,
34    TYPE_CHECKING,
35)
36
37from bumble.colors import color
38from bumble.l2cap import L2CAP_PDU
39from bumble.snoop import Snooper
40from bumble import drivers
41from bumble import hci
42from bumble.core import (
43    BT_BR_EDR_TRANSPORT,
44    BT_LE_TRANSPORT,
45    ConnectionPHY,
46    ConnectionParameters,
47)
48from bumble.utils import AbortableEventEmitter
49from bumble.transport.common import TransportLostError
50
51if TYPE_CHECKING:
52    from .transport.common import TransportSink, TransportSource
53
54
55# -----------------------------------------------------------------------------
56# Logging
57# -----------------------------------------------------------------------------
58logger = logging.getLogger(__name__)
59
60
61# -----------------------------------------------------------------------------
class AclPacketQueue:
    """FIFO of outgoing ACL data packets, gated by controller flow control.

    At most ``max_in_flight`` packets are handed to ``send`` before the
    controller reports completions through ``on_packets_completed``; any
    further packets wait in the queue in submission order.
    """

    max_packet_size: int

    def __init__(
        self,
        max_packet_size: int,
        max_in_flight: int,
        send: Callable[[hci.HCI_Packet], None],
    ) -> None:
        """
        Args:
            max_packet_size: largest ACL payload size, in bytes, for packets
                sent through this queue (used by callers to fragment PDUs).
            max_in_flight: maximum number of packets sent but not yet
                reported as completed.
            send: callable that transmits one packet to the controller.
        """
        self.max_packet_size = max_packet_size
        self.max_in_flight = max_in_flight
        self.in_flight = 0
        self.send = send
        self.packets: Deque[hci.HCI_AclDataPacket] = collections.deque()

    def enqueue(self, packet: hci.HCI_AclDataPacket) -> None:
        """Queue ``packet`` and send as many queued packets as allowed."""
        # New packets go on the left and are popped from the right: FIFO order.
        self.packets.appendleft(packet)
        self.check_queue()

        if self.packets:
            logger.debug(
                f'{self.in_flight} ACL packets in flight, '
                f'{len(self.packets)} in queue'
            )

    def check_queue(self) -> None:
        """Send queued packets, oldest first, while in-flight slots remain."""
        while self.packets and self.in_flight < self.max_in_flight:
            packet = self.packets.pop()
            self.send(packet)
            self.in_flight += 1

    def on_packets_completed(self, packet_count: int) -> None:
        """Release ``packet_count`` in-flight slots, then send pending packets.

        A count larger than the number of packets in flight is logged and
        clamped, so ``in_flight`` never goes negative.
        """
        if packet_count > self.in_flight:
            logger.warning(
                color(
                    f'!!! {packet_count} completed but only '
                    f'{self.in_flight} in flight'
                )
            )
            packet_count = self.in_flight

        self.in_flight -= packet_count
        self.check_queue()
105
106
107# -----------------------------------------------------------------------------
class Connection:
    """Host-side state for one ACL connection.

    Incoming ACL fragments are reassembled into PDUs, parsed as L2CAP and
    handed back to the host. Outgoing data goes through the host's LE ACL
    queue for LE connections and through the Classic ACL queue otherwise.
    """

    def __init__(
        self, host: Host, handle: int, peer_address: hci.Address, transport: int
    ):
        self.host = host
        self.handle = handle
        self.peer_address = peer_address
        self.assembler = hci.HCI_AclDataPacketAssembler(self.on_acl_pdu)
        self.transport = transport

        # LE links use the LE queue (which may be the same object as the
        # Classic queue); every other transport uses the Classic queue.
        acl_packet_queue: Optional[AclPacketQueue]
        if transport == BT_LE_TRANSPORT:
            acl_packet_queue = host.le_acl_packet_queue
        else:
            acl_packet_queue = host.acl_packet_queue
        assert acl_packet_queue
        self.acl_packet_queue = acl_packet_queue

    def on_hci_acl_data_packet(self, packet: hci.HCI_AclDataPacket) -> None:
        """Feed one incoming ACL data packet to the PDU reassembler."""
        self.assembler.feed_packet(packet)

    def on_acl_pdu(self, pdu: bytes) -> None:
        """Parse a reassembled ACL payload as L2CAP and pass it to the host."""
        frame = L2CAP_PDU.from_bytes(pdu)
        self.host.on_l2cap_pdu(self, frame.cid, frame.payload)
131
132
133# -----------------------------------------------------------------------------
@dataclasses.dataclass
class ScoLink:
    """A synchronous (SCO) link tracked by the host in ``Host.sco_links``."""

    peer_address: hci.Address  # address of the remote device
    handle: int  # HCI connection handle of the link
138
139
140# -----------------------------------------------------------------------------
@dataclasses.dataclass
class CisLink:
    """A CIS link tracked by the host in ``Host.cis_links``."""

    peer_address: hci.Address  # address of the remote device
    handle: int  # HCI connection handle of the link
145
146
147# -----------------------------------------------------------------------------
148class Host(AbortableEventEmitter):
149    connections: Dict[int, Connection]
150    cis_links: Dict[int, CisLink]
151    sco_links: Dict[int, ScoLink]
152    acl_packet_queue: Optional[AclPacketQueue] = None
153    le_acl_packet_queue: Optional[AclPacketQueue] = None
154    hci_sink: Optional[TransportSink] = None
155    hci_metadata: Dict[str, Any]
156    long_term_key_provider: Optional[
157        Callable[[int, bytes, int], Awaitable[Optional[bytes]]]
158    ]
159    link_key_provider: Optional[Callable[[hci.Address], Awaitable[Optional[bytes]]]]
160
161    def __init__(
162        self,
163        controller_source: Optional[TransportSource] = None,
164        controller_sink: Optional[TransportSink] = None,
165    ) -> None:
166        super().__init__()
167
168        self.hci_metadata = {}
169        self.ready = False  # True when we can accept incoming packets
170        self.connections = {}  # Connections, by connection handle
171        self.cis_links = {}  # CIS links, by connection handle
172        self.sco_links = {}  # SCO links, by connection handle
173        self.pending_command = None
174        self.pending_response: Optional[asyncio.Future[Any]] = None
175        self.number_of_supported_advertising_sets = 0
176        self.maximum_advertising_data_length = 31
177        self.local_version = None
178        self.local_supported_commands = 0
179        self.local_le_features = 0
180        self.local_lmp_features = hci.LmpFeatureMask(0)  # Classic LMP features
181        self.suggested_max_tx_octets = 251  # Max allowed
182        self.suggested_max_tx_time = 2120  # Max allowed
183        self.command_semaphore = asyncio.Semaphore(1)
184        self.long_term_key_provider = None
185        self.link_key_provider = None
186        self.pairing_io_capability_provider = None  # Classic only
187        self.snooper: Optional[Snooper] = None
188
189        # Connect to the source and sink if specified
190        if controller_source:
191            self.set_packet_source(controller_source)
192        if controller_sink:
193            self.set_packet_sink(controller_sink)
194
195    def find_connection_by_bd_addr(
196        self,
197        bd_addr: hci.Address,
198        transport: Optional[int] = None,
199        check_address_type: bool = False,
200    ) -> Optional[Connection]:
201        for connection in self.connections.values():
202            if connection.peer_address.to_bytes() == bd_addr.to_bytes():
203                if (
204                    check_address_type
205                    and connection.peer_address.address_type != bd_addr.address_type
206                ):
207                    continue
208                if transport is None or connection.transport == transport:
209                    return connection
210
211        return None
212
213    async def flush(self) -> None:
214        # Make sure no command is pending
215        await self.command_semaphore.acquire()
216
217        # Flush current host state, then release command semaphore
218        self.emit('flush')
219        self.command_semaphore.release()
220
    async def reset(self, driver_factory=drivers.get_driver_for_host):
        """Reset the controller and (re)load its capabilities into the host.

        Sequence:
        1. Flush existing host state.
        2. Initialize the controller through a driver, or send HCI_Reset when
           no driver is used.
        3. Read supported commands, LE/LMP features and version information.
        4. Program the event masks and create the ACL flow-control queues.
        5. Align the suggested default LE data length.
        6. Read the advertising-set limits.

        Args:
            driver_factory: async callable invoked with this host that returns
                a driver object (with an async ``init_controller()``) or a
                falsy value. Pass None to skip driver lookup entirely.

        Raises:
            hci.HCI_Error: if a command sent with ``check_result=True`` fails.
        """
        # A previous reset left the host ready: stop accepting controller
        # packets and let 'flush' listeners drop their state first.
        if self.ready:
            self.ready = False
            await self.flush()

        # Instantiate and init a driver for the host if needed.
        # NOTE: we don't keep a reference to the driver here, because we don't
        # currently have a need for the driver later on. But if the driver interface
        # evolves, it may be required, then, to store a reference to the driver in
        # an object property.
        reset_needed = True
        if driver_factory is not None:
            if driver := await driver_factory(self):
                await driver.init_controller()
                reset_needed = False

        # Send a reset command unless a driver has already done so.
        # NOTE(review): on the driver path this method never sets self.ready.
        # Presumably driver.init_controller() does; otherwise on_packet() would
        # drop the responses to the commands below. Confirm.
        if reset_needed:
            await self.send_command(hci.HCI_Reset_Command(), check_result=True)
            # From here on, on_packet() forwards every controller packet.
            self.ready = True

        # Supported-commands bitmask, consulted by supports_command().
        response = await self.send_command(
            hci.HCI_Read_Local_Supported_Commands_Command(), check_result=True
        )
        self.local_supported_commands = int.from_bytes(
            response.return_parameters.supported_commands, 'little'
        )

        # LE features arrive as an 8-byte little-endian bitmask.
        if self.supports_command(hci.HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
            response = await self.send_command(
                hci.HCI_LE_Read_Local_Supported_Features_Command(), check_result=True
            )
            self.local_le_features = struct.unpack(
                '<Q', response.return_parameters.le_features
            )[0]

        if self.supports_command(hci.HCI_READ_LOCAL_VERSION_INFORMATION_COMMAND):
            response = await self.send_command(
                hci.HCI_Read_Local_Version_Information_Command(), check_result=True
            )
            self.local_version = response.return_parameters

        # Prefer the extended (paged) LMP features. Each response reports the
        # highest page number; page N is packed into bits [64*N, 64*N + 63].
        if self.supports_command(hci.HCI_READ_LOCAL_EXTENDED_FEATURES_COMMAND):
            max_page_number = 0
            page_number = 0
            lmp_features = 0
            while page_number <= max_page_number:
                response = await self.send_command(
                    hci.HCI_Read_Local_Extended_Features_Command(
                        page_number=page_number
                    ),
                    check_result=True,
                )
                lmp_features |= int.from_bytes(
                    response.return_parameters.extended_lmp_features, 'little'
                ) << (64 * page_number)
                max_page_number = response.return_parameters.maximum_page_number
                page_number += 1
            self.local_lmp_features = hci.LmpFeatureMask(lmp_features)

        # Fall back to the single-page LMP features command.
        elif self.supports_command(hci.HCI_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
            response = await self.send_command(
                hci.HCI_Read_Local_Supported_Features_Command(), check_result=True
            )
            self.local_lmp_features = hci.LmpFeatureMask(
                int.from_bytes(response.return_parameters.lmp_features, 'little')
            )

        # Events the host wants to receive (HCI_LE_META_EVENT is needed for any
        # LE subevent to be delivered). Sent without check_result.
        await self.send_command(
            hci.HCI_Set_Event_Mask_Command(
                event_mask=hci.HCI_Set_Event_Mask_Command.mask(
                    [
                        hci.HCI_INQUIRY_COMPLETE_EVENT,
                        hci.HCI_INQUIRY_RESULT_EVENT,
                        hci.HCI_CONNECTION_COMPLETE_EVENT,
                        hci.HCI_CONNECTION_REQUEST_EVENT,
                        hci.HCI_DISCONNECTION_COMPLETE_EVENT,
                        hci.HCI_AUTHENTICATION_COMPLETE_EVENT,
                        hci.HCI_REMOTE_NAME_REQUEST_COMPLETE_EVENT,
                        hci.HCI_ENCRYPTION_CHANGE_EVENT,
                        hci.HCI_CHANGE_CONNECTION_LINK_KEY_COMPLETE_EVENT,
                        hci.HCI_LINK_KEY_TYPE_CHANGED_EVENT,
                        hci.HCI_READ_REMOTE_SUPPORTED_FEATURES_COMPLETE_EVENT,
                        hci.HCI_READ_REMOTE_VERSION_INFORMATION_COMPLETE_EVENT,
                        hci.HCI_QOS_SETUP_COMPLETE_EVENT,
                        hci.HCI_HARDWARE_ERROR_EVENT,
                        hci.HCI_FLUSH_OCCURRED_EVENT,
                        hci.HCI_ROLE_CHANGE_EVENT,
                        hci.HCI_MODE_CHANGE_EVENT,
                        hci.HCI_RETURN_LINK_KEYS_EVENT,
                        hci.HCI_PIN_CODE_REQUEST_EVENT,
                        hci.HCI_LINK_KEY_REQUEST_EVENT,
                        hci.HCI_LINK_KEY_NOTIFICATION_EVENT,
                        hci.HCI_LOOPBACK_COMMAND_EVENT,
                        hci.HCI_DATA_BUFFER_OVERFLOW_EVENT,
                        hci.HCI_MAX_SLOTS_CHANGE_EVENT,
                        hci.HCI_READ_CLOCK_OFFSET_COMPLETE_EVENT,
                        hci.HCI_CONNECTION_PACKET_TYPE_CHANGED_EVENT,
                        hci.HCI_QOS_VIOLATION_EVENT,
                        hci.HCI_PAGE_SCAN_REPETITION_MODE_CHANGE_EVENT,
                        hci.HCI_FLOW_SPECIFICATION_COMPLETE_EVENT,
                        hci.HCI_INQUIRY_RESULT_WITH_RSSI_EVENT,
                        hci.HCI_READ_REMOTE_EXTENDED_FEATURES_COMPLETE_EVENT,
                        hci.HCI_SYNCHRONOUS_CONNECTION_COMPLETE_EVENT,
                        hci.HCI_SYNCHRONOUS_CONNECTION_CHANGED_EVENT,
                        hci.HCI_SNIFF_SUBRATING_EVENT,
                        hci.HCI_EXTENDED_INQUIRY_RESULT_EVENT,
                        hci.HCI_ENCRYPTION_KEY_REFRESH_COMPLETE_EVENT,
                        hci.HCI_IO_CAPABILITY_REQUEST_EVENT,
                        hci.HCI_IO_CAPABILITY_RESPONSE_EVENT,
                        hci.HCI_USER_CONFIRMATION_REQUEST_EVENT,
                        hci.HCI_USER_PASSKEY_REQUEST_EVENT,
                        hci.HCI_REMOTE_OOB_DATA_REQUEST_EVENT,
                        hci.HCI_SIMPLE_PAIRING_COMPLETE_EVENT,
                        hci.HCI_LINK_SUPERVISION_TIMEOUT_CHANGED_EVENT,
                        hci.HCI_ENHANCED_FLUSH_COMPLETE_EVENT,
                        hci.HCI_USER_PASSKEY_NOTIFICATION_EVENT,
                        hci.HCI_KEYPRESS_NOTIFICATION_EVENT,
                        hci.HCI_REMOTE_HOST_SUPPORTED_FEATURES_NOTIFICATION_EVENT,
                        hci.HCI_LE_META_EVENT,
                    ]
                )
            )
        )

        if (
            self.local_version is not None
            and self.local_version.hci_version <= hci.HCI_VERSION_BLUETOOTH_CORE_4_0
        ):
            # Some older controllers don't like event masks with bits they don't
            # understand
            # (0x1F sets only the five lowest LE event mask bits.)
            le_event_mask = bytes.fromhex('1F00000000000000')
        else:
            le_event_mask = hci.HCI_LE_Set_Event_Mask_Command.mask(
                [
                    hci.HCI_LE_CONNECTION_COMPLETE_EVENT,
                    hci.HCI_LE_ADVERTISING_REPORT_EVENT,
                    hci.HCI_LE_CONNECTION_UPDATE_COMPLETE_EVENT,
                    hci.HCI_LE_READ_REMOTE_FEATURES_COMPLETE_EVENT,
                    hci.HCI_LE_LONG_TERM_KEY_REQUEST_EVENT,
                    hci.HCI_LE_REMOTE_CONNECTION_PARAMETER_REQUEST_EVENT,
                    hci.HCI_LE_DATA_LENGTH_CHANGE_EVENT,
                    hci.HCI_LE_READ_LOCAL_P_256_PUBLIC_KEY_COMPLETE_EVENT,
                    hci.HCI_LE_GENERATE_DHKEY_COMPLETE_EVENT,
                    hci.HCI_LE_ENHANCED_CONNECTION_COMPLETE_EVENT,
                    hci.HCI_LE_DIRECTED_ADVERTISING_REPORT_EVENT,
                    hci.HCI_LE_PHY_UPDATE_COMPLETE_EVENT,
                    hci.HCI_LE_EXTENDED_ADVERTISING_REPORT_EVENT,
                    hci.HCI_LE_PERIODIC_ADVERTISING_SYNC_ESTABLISHED_EVENT,
                    hci.HCI_LE_PERIODIC_ADVERTISING_REPORT_EVENT,
                    hci.HCI_LE_PERIODIC_ADVERTISING_SYNC_LOST_EVENT,
                    hci.HCI_LE_SCAN_TIMEOUT_EVENT,
                    hci.HCI_LE_ADVERTISING_SET_TERMINATED_EVENT,
                    hci.HCI_LE_SCAN_REQUEST_RECEIVED_EVENT,
                    hci.HCI_LE_CONNECTIONLESS_IQ_REPORT_EVENT,
                    hci.HCI_LE_CONNECTION_IQ_REPORT_EVENT,
                    hci.HCI_LE_CTE_REQUEST_FAILED_EVENT,
                    hci.HCI_LE_PERIODIC_ADVERTISING_SYNC_TRANSFER_RECEIVED_EVENT,
                    hci.HCI_LE_CIS_ESTABLISHED_EVENT,
                    hci.HCI_LE_CIS_REQUEST_EVENT,
                    hci.HCI_LE_CREATE_BIG_COMPLETE_EVENT,
                    hci.HCI_LE_TERMINATE_BIG_COMPLETE_EVENT,
                    hci.HCI_LE_BIG_SYNC_ESTABLISHED_EVENT,
                    hci.HCI_LE_BIG_SYNC_LOST_EVENT,
                    hci.HCI_LE_REQUEST_PEER_SCA_COMPLETE_EVENT,
                    hci.HCI_LE_PATH_LOSS_THRESHOLD_EVENT,
                    hci.HCI_LE_TRANSMIT_POWER_REPORTING_EVENT,
                    hci.HCI_LE_BIGINFO_ADVERTISING_REPORT_EVENT,
                    hci.HCI_LE_SUBRATE_CHANGE_EVENT,
                ]
            )

        await self.send_command(
            hci.HCI_LE_Set_Event_Mask_Command(le_event_mask=le_event_mask)
        )

        # Classic ACL flow control: packet size and number of packets the
        # controller can buffer.
        # NOTE(review): if this command is unsupported, acl_packet_queue is not
        # (re)created here. It stays None on a first reset, or keeps the queue
        # from a previous reset.
        if self.supports_command(hci.HCI_READ_BUFFER_SIZE_COMMAND):
            response = await self.send_command(
                hci.HCI_Read_Buffer_Size_Command(), check_result=True
            )
            hc_acl_data_packet_length = (
                response.return_parameters.hc_acl_data_packet_length
            )
            hc_total_num_acl_data_packets = (
                response.return_parameters.hc_total_num_acl_data_packets
            )

            logger.debug(
                'HCI ACL flow control: '
                f'hc_acl_data_packet_length={hc_acl_data_packet_length},'
                f'hc_total_num_acl_data_packets={hc_total_num_acl_data_packets}'
            )

            self.acl_packet_queue = AclPacketQueue(
                max_packet_size=hc_acl_data_packet_length,
                max_in_flight=hc_total_num_acl_data_packets,
                send=self.send_hci_packet,
            )

        # LE ACL flow control. Zero values (or an unsupported command) mean the
        # controller has no dedicated LE buffers.
        hc_le_acl_data_packet_length = 0
        hc_total_num_le_acl_data_packets = 0
        if self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_COMMAND):
            response = await self.send_command(
                hci.HCI_LE_Read_Buffer_Size_Command(), check_result=True
            )
            hc_le_acl_data_packet_length = (
                response.return_parameters.hc_le_acl_data_packet_length
            )
            hc_total_num_le_acl_data_packets = (
                response.return_parameters.hc_total_num_le_acl_data_packets
            )

            logger.debug(
                'HCI LE ACL flow control: '
                f'hc_le_acl_data_packet_length={hc_le_acl_data_packet_length},'
                f'hc_total_num_le_acl_data_packets={hc_total_num_le_acl_data_packets}'
            )

        if hc_le_acl_data_packet_length == 0 or hc_total_num_le_acl_data_packets == 0:
            # LE and Classic share the same queue
            self.le_acl_packet_queue = self.acl_packet_queue
        else:
            # Create a separate queue for LE
            self.le_acl_packet_queue = AclPacketQueue(
                max_packet_size=hc_le_acl_data_packet_length,
                max_in_flight=hc_total_num_le_acl_data_packets,
                send=self.send_hci_packet,
            )

        # Only rewrite the suggested default data length when the controller's
        # current values differ from the host's preferred ones. Note that the
        # read is sent without check_result.
        if self.supports_command(
            hci.HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND
        ) and self.supports_command(
            hci.HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND
        ):
            response = await self.send_command(
                hci.HCI_LE_Read_Suggested_Default_Data_Length_Command()
            )
            suggested_max_tx_octets = response.return_parameters.suggested_max_tx_octets
            suggested_max_tx_time = response.return_parameters.suggested_max_tx_time
            if (
                suggested_max_tx_octets != self.suggested_max_tx_octets
                or suggested_max_tx_time != self.suggested_max_tx_time
            ):
                await self.send_command(
                    hci.HCI_LE_Write_Suggested_Default_Data_Length_Command(
                        suggested_max_tx_octets=self.suggested_max_tx_octets,
                        suggested_max_tx_time=self.suggested_max_tx_time,
                    )
                )

        # Extended advertising limits (defaults from __init__ are kept when the
        # commands are unsupported).
        if self.supports_command(
            hci.HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND
        ):
            response = await self.send_command(
                hci.HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command(),
                check_result=True,
            )
            self.number_of_supported_advertising_sets = (
                response.return_parameters.num_supported_advertising_sets
            )

        if self.supports_command(
            hci.HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND
        ):
            response = await self.send_command(
                hci.HCI_LE_Read_Maximum_Advertising_Data_Length_Command(),
                check_result=True,
            )
            self.maximum_advertising_data_length = (
                response.return_parameters.max_advertising_data_length
            )
492
493    @property
494    def controller(self) -> Optional[TransportSink]:
495        return self.hci_sink
496
497    @controller.setter
498    def controller(self, controller) -> None:
499        self.set_packet_sink(controller)
500        if controller:
501            self.set_packet_source(controller)
502
503    def set_packet_sink(self, sink: Optional[TransportSink]) -> None:
504        self.hci_sink = sink
505
506    def set_packet_source(self, source: TransportSource) -> None:
507        source.set_packet_sink(self)
508        self.hci_metadata = getattr(source, 'metadata', self.hci_metadata)
509
510    def send_hci_packet(self, packet: hci.HCI_Packet) -> None:
511        logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {packet}')
512        if self.snooper:
513            self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)
514        if self.hci_sink:
515            self.hci_sink.on_packet(bytes(packet))
516
517    async def send_command(
518        self, command, check_result=False, response_timeout: Optional[int] = None
519    ):
520        # Wait until we can send (only one pending command at a time)
521        async with self.command_semaphore:
522            assert self.pending_command is None
523            assert self.pending_response is None
524
525            # Create a future value to hold the eventual response
526            self.pending_response = asyncio.get_running_loop().create_future()
527            self.pending_command = command
528
529            try:
530                self.send_hci_packet(command)
531                await asyncio.wait_for(self.pending_response, timeout=response_timeout)
532                response = self.pending_response.result()
533
534                # Check the return parameters if required
535                if check_result:
536                    if isinstance(response, hci.HCI_Command_Status_Event):
537                        status = response.status  # type: ignore[attr-defined]
538                    elif isinstance(response.return_parameters, int):
539                        status = response.return_parameters
540                    elif isinstance(response.return_parameters, bytes):
541                        # return parameters first field is a one byte status code
542                        status = response.return_parameters[0]
543                    else:
544                        status = response.return_parameters.status
545
546                    if status != hci.HCI_SUCCESS:
547                        logger.warning(
548                            f'{command.name} failed '
549                            f'({hci.HCI_Constant.error_name(status)})'
550                        )
551                        raise hci.HCI_Error(status)
552
553                return response
554            except Exception as error:
555                logger.warning(
556                    f'{color("!!! Exception while sending command:", "red")} {error}'
557                )
558                raise error
559            finally:
560                self.pending_command = None
561                self.pending_response = None
562
563    # Use this method to send a command from a task
564    def send_command_sync(self, command: hci.HCI_Command) -> None:
565        async def send_command(command: hci.HCI_Command) -> None:
566            await self.send_command(command)
567
568        asyncio.create_task(send_command(command))
569
570    def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
571        if not (connection := self.connections.get(connection_handle)):
572            logger.warning(f'connection 0x{connection_handle:04X} not found')
573            return
574        packet_queue = connection.acl_packet_queue
575        if packet_queue is None:
576            logger.warning(
577                f'no ACL packet queue for connection 0x{connection_handle:04X}'
578            )
579            return
580
581        # Create a PDU
582        l2cap_pdu = bytes(L2CAP_PDU(cid, pdu))
583
584        # Send the data to the controller via ACL packets
585        bytes_remaining = len(l2cap_pdu)
586        offset = 0
587        pb_flag = 0
588        while bytes_remaining:
589            data_total_length = min(bytes_remaining, packet_queue.max_packet_size)
590            acl_packet = hci.HCI_AclDataPacket(
591                connection_handle=connection_handle,
592                pb_flag=pb_flag,
593                bc_flag=0,
594                data_total_length=data_total_length,
595                data=l2cap_pdu[offset : offset + data_total_length],
596            )
597            logger.debug(f'>>> ACL packet enqueue: (CID={cid}) {acl_packet}')
598            packet_queue.enqueue(acl_packet)
599            pb_flag = 1
600            offset += data_total_length
601            bytes_remaining -= data_total_length
602
603    def supports_command(self, op_code: int) -> bool:
604        return (
605            self.local_supported_commands
606            & hci.HCI_SUPPORTED_COMMANDS_MASKS.get(op_code, 0)
607        ) != 0
608
609    @property
610    def supported_commands(self) -> Set[int]:
611        return set(
612            op_code
613            for op_code, mask in hci.HCI_SUPPORTED_COMMANDS_MASKS.items()
614            if self.local_supported_commands & mask
615        )
616
617    def supports_le_features(self, feature: hci.LeFeatureMask) -> bool:
618        return (self.local_le_features & feature) == feature
619
620    def supports_lmp_features(self, feature: hci.LmpFeatureMask) -> bool:
621        return self.local_lmp_features & (feature) == feature
622
623    @property
624    def supported_le_features(self):
625        return [
626            feature for feature in range(64) if self.local_le_features & (1 << feature)
627        ]
628
629    # Packet Sink protocol (packets coming from the controller via HCI)
630    def on_packet(self, packet: bytes) -> None:
631        try:
632            hci_packet = hci.HCI_Packet.from_bytes(packet)
633        except Exception as error:
634            logger.warning(f'!!! error parsing packet from bytes: {error}')
635            return
636
637        if self.ready or (
638            isinstance(hci_packet, hci.HCI_Command_Complete_Event)
639            and hci_packet.command_opcode == hci.HCI_RESET_COMMAND
640        ):
641            self.on_hci_packet(hci_packet)
642        else:
643            logger.debug(
644                f'reset not done, ignoring packet from controller: {hci_packet}'
645            )
646
647    def on_transport_lost(self):
648        # Called by the source when the transport has been lost.
649        if self.pending_response:
650            self.pending_response.set_exception(TransportLostError('transport lost'))
651
652        self.emit('flush')
653
654    def on_hci_packet(self, packet: hci.HCI_Packet) -> None:
655        logger.debug(f'{color("### CONTROLLER -> HOST", "green")}: {packet}')
656
657        if self.snooper:
658            self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST)
659
660        # If the packet is a command, invoke the handler for this packet
661        if packet.hci_packet_type == hci.HCI_COMMAND_PACKET:
662            self.on_hci_command_packet(cast(hci.HCI_Command, packet))
663        elif packet.hci_packet_type == hci.HCI_EVENT_PACKET:
664            self.on_hci_event_packet(cast(hci.HCI_Event, packet))
665        elif packet.hci_packet_type == hci.HCI_ACL_DATA_PACKET:
666            self.on_hci_acl_data_packet(cast(hci.HCI_AclDataPacket, packet))
667        elif packet.hci_packet_type == hci.HCI_SYNCHRONOUS_DATA_PACKET:
668            self.on_hci_sco_data_packet(cast(hci.HCI_SynchronousDataPacket, packet))
669        elif packet.hci_packet_type == hci.HCI_ISO_DATA_PACKET:
670            self.on_hci_iso_data_packet(cast(hci.HCI_IsoDataPacket, packet))
671        else:
672            logger.warning(f'!!! unknown packet type {packet.hci_packet_type}')
673
674    def on_hci_command_packet(self, command: hci.HCI_Command) -> None:
675        logger.warning(f'!!! unexpected command packet: {command}')
676
677    def on_hci_event_packet(self, event: hci.HCI_Event) -> None:
678        handler_name = f'on_{event.name.lower()}'
679        handler = getattr(self, handler_name, self.on_hci_event)
680        handler(event)
681
682    def on_hci_acl_data_packet(self, packet: hci.HCI_AclDataPacket) -> None:
683        # Look for the connection to which this data belongs
684        if connection := self.connections.get(packet.connection_handle):
685            connection.on_hci_acl_data_packet(packet)
686
687    def on_hci_sco_data_packet(self, packet: hci.HCI_SynchronousDataPacket) -> None:
688        # Experimental
689        self.emit('sco_packet', packet.connection_handle, packet)
690
691    def on_hci_iso_data_packet(self, packet: hci.HCI_IsoDataPacket) -> None:
692        # Experimental
693        self.emit('iso_packet', packet.connection_handle, packet)
694
695    def on_l2cap_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
696        self.emit('l2cap_pdu', connection.handle, cid, pdu)
697
698    def on_command_processed(self, event):
699        if self.pending_response:
700            # Check that it is what we were expecting
701            if self.pending_command.op_code != event.command_opcode:
702                logger.warning(
703                    '!!! command result mismatch, expected '
704                    f'0x{self.pending_command.op_code:X} but got '
705                    f'0x{event.command_opcode:X}'
706                )
707
708            self.pending_response.set_result(event)
709        else:
710            logger.warning('!!! no pending response future to set')
711
712    ############################################################
713    # HCI handlers
714    ############################################################
715    def on_hci_event(self, event):
716        logger.warning(f'{color(f"--- Ignoring event {event}", "red")}')
717
718    def on_hci_command_complete_event(self, event):
719        if event.command_opcode == 0:
720            # This is used just for the Num_HCI_Command_Packets field, not related to
721            # an actual command
722            logger.debug('no-command event')
723            return
724
725        return self.on_command_processed(event)
726
727    def on_hci_command_status_event(self, event):
728        return self.on_command_processed(event)
729
730    def on_hci_number_of_completed_packets_event(self, event):
731        for connection_handle, num_completed_packets in zip(
732            event.connection_handles, event.num_completed_packets
733        ):
734            if connection := self.connections.get(connection_handle):
735                connection.acl_packet_queue.on_packets_completed(num_completed_packets)
736            elif not (
737                self.cis_links.get(connection_handle)
738                or self.sco_links.get(connection_handle)
739            ):
740                logger.warning(
741                    'received packet completion event for unknown handle '
742                    f'0x{connection_handle:04X}'
743                )
744
745    # Classic only
746    def on_hci_connection_request_event(self, event):
747        # Notify the listeners
748        self.emit(
749            'connection_request',
750            event.bd_addr,
751            event.class_of_device,
752            event.link_type,
753        )
754
755    def on_hci_le_connection_complete_event(self, event):
756        # Check if this is a cancellation
757        if event.status == hci.HCI_SUCCESS:
758            # Create/update the connection
759            logger.debug(
760                f'### LE CONNECTION: [0x{event.connection_handle:04X}] '
761                f'{event.peer_address} as {hci.HCI_Constant.role_name(event.role)}'
762            )
763
764            connection = self.connections.get(event.connection_handle)
765            if connection is None:
766                connection = Connection(
767                    self,
768                    event.connection_handle,
769                    event.peer_address,
770                    BT_LE_TRANSPORT,
771                )
772                self.connections[event.connection_handle] = connection
773
774            # Notify the client
775            connection_parameters = ConnectionParameters(
776                event.connection_interval,
777                event.peripheral_latency,
778                event.supervision_timeout,
779            )
780            self.emit(
781                'connection',
782                event.connection_handle,
783                BT_LE_TRANSPORT,
784                event.peer_address,
785                getattr(event, 'local_resolvable_private_address', None),
786                getattr(event, 'peer_resolvable_private_address', None),
787                event.role,
788                connection_parameters,
789            )
790        else:
791            logger.debug(f'### CONNECTION FAILED: {event.status}')
792
793            # Notify the listeners
794            self.emit(
795                'connection_failure', BT_LE_TRANSPORT, event.peer_address, event.status
796            )
797
798    def on_hci_le_enhanced_connection_complete_event(self, event):
799        # Just use the same implementation as for the non-enhanced event for now
800        self.on_hci_le_connection_complete_event(event)
801
802    def on_hci_le_enhanced_connection_complete_v2_event(self, event):
803        # Just use the same implementation as for the v1 event for now
804        self.on_hci_le_enhanced_connection_complete_event(event)
805
806    def on_hci_connection_complete_event(self, event):
807        if event.status == hci.HCI_SUCCESS:
808            # Create/update the connection
809            logger.debug(
810                f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] '
811                f'{event.bd_addr}'
812            )
813
814            connection = self.connections.get(event.connection_handle)
815            if connection is None:
816                connection = Connection(
817                    self,
818                    event.connection_handle,
819                    event.bd_addr,
820                    BT_BR_EDR_TRANSPORT,
821                )
822                self.connections[event.connection_handle] = connection
823
824            # Notify the client
825            self.emit(
826                'connection',
827                event.connection_handle,
828                BT_BR_EDR_TRANSPORT,
829                event.bd_addr,
830                None,
831                None,
832                None,
833                None,
834            )
835        else:
836            logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}')
837
838            # Notify the client
839            self.emit(
840                'connection_failure', BT_BR_EDR_TRANSPORT, event.bd_addr, event.status
841            )
842
843    def on_hci_disconnection_complete_event(self, event):
844        # Find the connection
845        handle = event.connection_handle
846        if (
847            connection := (
848                self.connections.get(handle)
849                or self.cis_links.get(handle)
850                or self.sco_links.get(handle)
851            )
852        ) is None:
853            logger.warning('!!! DISCONNECTION COMPLETE: unknown handle')
854            return
855
856        if event.status == hci.HCI_SUCCESS:
857            logger.debug(
858                f'### DISCONNECTION: [0x{handle:04X}] '
859                f'{connection.peer_address} '
860                f'reason={event.reason}'
861            )
862
863            # Notify the listeners
864            self.emit('disconnection', handle, event.reason)
865
866            # Remove the handle reference
867            _ = (
868                self.connections.pop(handle, 0)
869                or self.cis_links.pop(handle, 0)
870                or self.sco_links.pop(handle, 0)
871            )
872        else:
873            logger.debug(f'### DISCONNECTION FAILED: {event.status}')
874
875            # Notify the listeners
876            self.emit('disconnection_failure', handle, event.status)
877
878    def on_hci_le_connection_update_complete_event(self, event):
879        if (connection := self.connections.get(event.connection_handle)) is None:
880            logger.warning('!!! CONNECTION PARAMETERS UPDATE COMPLETE: unknown handle')
881            return
882
883        # Notify the client
884        if event.status == hci.HCI_SUCCESS:
885            connection_parameters = ConnectionParameters(
886                event.connection_interval,
887                event.peripheral_latency,
888                event.supervision_timeout,
889            )
890            self.emit(
891                'connection_parameters_update', connection.handle, connection_parameters
892            )
893        else:
894            self.emit(
895                'connection_parameters_update_failure', connection.handle, event.status
896            )
897
898    def on_hci_le_phy_update_complete_event(self, event):
899        if (connection := self.connections.get(event.connection_handle)) is None:
900            logger.warning('!!! CONNECTION PHY UPDATE COMPLETE: unknown handle')
901            return
902
903        # Notify the client
904        if event.status == hci.HCI_SUCCESS:
905            connection_phy = ConnectionPHY(event.tx_phy, event.rx_phy)
906            self.emit('connection_phy_update', connection.handle, connection_phy)
907        else:
908            self.emit('connection_phy_update_failure', connection.handle, event.status)
909
910    def on_hci_le_advertising_report_event(self, event):
911        for report in event.reports:
912            self.emit('advertising_report', report)
913
914    def on_hci_le_extended_advertising_report_event(self, event):
915        self.on_hci_le_advertising_report_event(event)
916
917    def on_hci_le_advertising_set_terminated_event(self, event):
918        self.emit(
919            'advertising_set_termination',
920            event.status,
921            event.advertising_handle,
922            event.connection_handle,
923            event.num_completed_extended_advertising_events,
924        )
925
926    def on_hci_le_periodic_advertising_sync_established_event(self, event):
927        self.emit(
928            'periodic_advertising_sync_establishment',
929            event.status,
930            event.sync_handle,
931            event.advertising_sid,
932            event.advertiser_address,
933            event.advertiser_phy,
934            event.periodic_advertising_interval,
935            event.advertiser_clock_accuracy,
936        )
937
938    def on_hci_le_periodic_advertising_sync_lost_event(self, event):
939        self.emit('periodic_advertising_sync_loss', event.sync_handle)
940
941    def on_hci_le_periodic_advertising_report_event(self, event):
942        self.emit('periodic_advertising_report', event.sync_handle, event)
943
944    def on_hci_le_biginfo_advertising_report_event(self, event):
945        self.emit('biginfo_advertising_report', event.sync_handle, event)
946
947    def on_hci_le_cis_request_event(self, event):
948        self.emit(
949            'cis_request',
950            event.acl_connection_handle,
951            event.cis_connection_handle,
952            event.cig_id,
953            event.cis_id,
954        )
955
956    def on_hci_le_cis_established_event(self, event):
957        # The remaining parameters are unused for now.
958        if event.status == hci.HCI_SUCCESS:
959            self.cis_links[event.connection_handle] = CisLink(
960                handle=event.connection_handle,
961                peer_address=hci.Address.ANY,
962            )
963            self.emit('cis_establishment', event.connection_handle)
964        else:
965            self.emit(
966                'cis_establishment_failure', event.connection_handle, event.status
967            )
968
969    def on_hci_le_remote_connection_parameter_request_event(self, event):
970        if event.connection_handle not in self.connections:
971            logger.warning('!!! REMOTE CONNECTION PARAMETER REQUEST: unknown handle')
972            return
973
974        # For now, just accept everything
975        # TODO: delegate the decision
976        self.send_command_sync(
977            hci.HCI_LE_Remote_Connection_Parameter_Request_Reply_Command(
978                connection_handle=event.connection_handle,
979                interval_min=event.interval_min,
980                interval_max=event.interval_max,
981                max_latency=event.max_latency,
982                timeout=event.timeout,
983                min_ce_length=0,
984                max_ce_length=0,
985            )
986        )
987
988    def on_hci_le_long_term_key_request_event(self, event):
989        if (connection := self.connections.get(event.connection_handle)) is None:
990            logger.warning('!!! LE LONG TERM KEY REQUEST: unknown handle')
991            return
992
993        async def send_long_term_key():
994            if self.long_term_key_provider is None:
995                logger.debug('no long term key provider')
996                long_term_key = None
997            else:
998                long_term_key = await self.abort_on(
999                    'flush',
1000                    # pylint: disable-next=not-callable
1001                    self.long_term_key_provider(
1002                        connection.handle,
1003                        event.random_number,
1004                        event.encryption_diversifier,
1005                    ),
1006                )
1007            if long_term_key:
1008                response = hci.HCI_LE_Long_Term_Key_Request_Reply_Command(
1009                    connection_handle=event.connection_handle,
1010                    long_term_key=long_term_key,
1011                )
1012            else:
1013                response = hci.HCI_LE_Long_Term_Key_Request_Negative_Reply_Command(
1014                    connection_handle=event.connection_handle
1015                )
1016
1017            await self.send_command(response)
1018
1019        asyncio.create_task(send_long_term_key())
1020
1021    def on_hci_synchronous_connection_complete_event(self, event):
1022        if event.status == hci.HCI_SUCCESS:
1023            # Create/update the connection
1024            logger.debug(
1025                f'### SCO CONNECTION: [0x{event.connection_handle:04X}] '
1026                f'{event.bd_addr}'
1027            )
1028
1029            self.sco_links[event.connection_handle] = ScoLink(
1030                peer_address=event.bd_addr,
1031                handle=event.connection_handle,
1032            )
1033
1034            # Notify the client
1035            self.emit(
1036                'sco_connection',
1037                event.bd_addr,
1038                event.connection_handle,
1039                event.link_type,
1040            )
1041        else:
1042            logger.debug(f'### SCO CONNECTION FAILED: {event.status}')
1043
1044            # Notify the client
1045            self.emit('sco_connection_failure', event.bd_addr, event.status)
1046
1047    def on_hci_synchronous_connection_changed_event(self, event):
1048        pass
1049
1050    def on_hci_role_change_event(self, event):
1051        if event.status == hci.HCI_SUCCESS:
1052            logger.debug(
1053                f'role change for {event.bd_addr}: '
1054                f'{hci.HCI_Constant.role_name(event.new_role)}'
1055            )
1056            self.emit('role_change', event.bd_addr, event.new_role)
1057        else:
1058            logger.debug(
1059                f'role change for {event.bd_addr} failed: '
1060                f'{hci.HCI_Constant.error_name(event.status)}'
1061            )
1062            self.emit('role_change_failure', event.bd_addr, event.status)
1063
1064    def on_hci_le_data_length_change_event(self, event):
1065        self.emit(
1066            'connection_data_length_change',
1067            event.connection_handle,
1068            event.max_tx_octets,
1069            event.max_tx_time,
1070            event.max_rx_octets,
1071            event.max_rx_time,
1072        )
1073
1074    def on_hci_authentication_complete_event(self, event):
1075        # Notify the client
1076        if event.status == hci.HCI_SUCCESS:
1077            self.emit('connection_authentication', event.connection_handle)
1078        else:
1079            self.emit(
1080                'connection_authentication_failure',
1081                event.connection_handle,
1082                event.status,
1083            )
1084
1085    def on_hci_encryption_change_event(self, event):
1086        # Notify the client
1087        if event.status == hci.HCI_SUCCESS:
1088            self.emit(
1089                'connection_encryption_change',
1090                event.connection_handle,
1091                event.encryption_enabled,
1092            )
1093        else:
1094            self.emit(
1095                'connection_encryption_failure', event.connection_handle, event.status
1096            )
1097
1098    def on_hci_encryption_key_refresh_complete_event(self, event):
1099        # Notify the client
1100        if event.status == hci.HCI_SUCCESS:
1101            self.emit('connection_encryption_key_refresh', event.connection_handle)
1102        else:
1103            self.emit(
1104                'connection_encryption_key_refresh_failure',
1105                event.connection_handle,
1106                event.status,
1107            )
1108
1109    def on_hci_link_supervision_timeout_changed_event(self, event):
1110        pass
1111
1112    def on_hci_max_slots_change_event(self, event):
1113        pass
1114
1115    def on_hci_page_scan_repetition_mode_change_event(self, event):
1116        pass
1117
1118    def on_hci_link_key_notification_event(self, event):
1119        logger.debug(
1120            f'link key for {event.bd_addr}: {event.link_key.hex()}, '
1121            f'type={hci.HCI_Constant.link_key_type_name(event.key_type)}'
1122        )
1123        self.emit('link_key', event.bd_addr, event.link_key, event.key_type)
1124
1125    def on_hci_simple_pairing_complete_event(self, event):
1126        logger.debug(
1127            f'simple pairing complete for {event.bd_addr}: '
1128            f'status={hci.HCI_Constant.status_name(event.status)}'
1129        )
1130        if event.status == hci.HCI_SUCCESS:
1131            self.emit('classic_pairing', event.bd_addr)
1132        else:
1133            self.emit('classic_pairing_failure', event.bd_addr, event.status)
1134
1135    def on_hci_pin_code_request_event(self, event):
1136        self.emit('pin_code_request', event.bd_addr)
1137
1138    def on_hci_link_key_request_event(self, event):
1139        async def send_link_key():
1140            if self.link_key_provider is None:
1141                logger.debug('no link key provider')
1142                link_key = None
1143            else:
1144                link_key = await self.abort_on(
1145                    'flush',
1146                    # pylint: disable-next=not-callable
1147                    self.link_key_provider(event.bd_addr),
1148                )
1149            if link_key:
1150                response = hci.HCI_Link_Key_Request_Reply_Command(
1151                    bd_addr=event.bd_addr, link_key=link_key
1152                )
1153            else:
1154                response = hci.HCI_Link_Key_Request_Negative_Reply_Command(
1155                    bd_addr=event.bd_addr
1156                )
1157
1158            await self.send_command(response)
1159
1160        asyncio.create_task(send_link_key())
1161
1162    def on_hci_io_capability_request_event(self, event):
1163        self.emit('authentication_io_capability_request', event.bd_addr)
1164
1165    def on_hci_io_capability_response_event(self, event):
1166        self.emit(
1167            'authentication_io_capability_response',
1168            event.bd_addr,
1169            event.io_capability,
1170            event.authentication_requirements,
1171        )
1172
1173    def on_hci_user_confirmation_request_event(self, event):
1174        self.emit(
1175            'authentication_user_confirmation_request',
1176            event.bd_addr,
1177            event.numeric_value,
1178        )
1179
1180    def on_hci_user_passkey_request_event(self, event):
1181        self.emit('authentication_user_passkey_request', event.bd_addr)
1182
1183    def on_hci_user_passkey_notification_event(self, event):
1184        self.emit(
1185            'authentication_user_passkey_notification', event.bd_addr, event.passkey
1186        )
1187
1188    def on_hci_inquiry_complete_event(self, _event):
1189        self.emit('inquiry_complete')
1190
1191    def on_hci_inquiry_result_with_rssi_event(self, event):
1192        for response in event.responses:
1193            self.emit(
1194                'inquiry_result',
1195                response.bd_addr,
1196                response.class_of_device,
1197                b'',
1198                response.rssi,
1199            )
1200
1201    def on_hci_extended_inquiry_result_event(self, event):
1202        self.emit(
1203            'inquiry_result',
1204            event.bd_addr,
1205            event.class_of_device,
1206            event.extended_inquiry_response,
1207            event.rssi,
1208        )
1209
1210    def on_hci_remote_name_request_complete_event(self, event):
1211        if event.status != hci.HCI_SUCCESS:
1212            self.emit('remote_name_failure', event.bd_addr, event.status)
1213        else:
1214            utf8_name = event.remote_name
1215            terminator = utf8_name.find(0)
1216            if terminator >= 0:
1217                utf8_name = utf8_name[0:terminator]
1218
1219            self.emit('remote_name', event.bd_addr, utf8_name)
1220
1221    def on_hci_remote_host_supported_features_notification_event(self, event):
1222        self.emit(
1223            'remote_host_supported_features',
1224            event.bd_addr,
1225            event.host_supported_features,
1226        )
1227
1228    def on_hci_le_read_remote_features_complete_event(self, event):
1229        if event.status != hci.HCI_SUCCESS:
1230            self.emit(
1231                'le_remote_features_failure', event.connection_handle, event.status
1232            )
1233        else:
1234            self.emit(
1235                'le_remote_features',
1236                event.connection_handle,
1237                int.from_bytes(event.le_features, 'little'),
1238            )
1239