1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19
20import logging
21import asyncio
22import collections
23import dataclasses
24import enum
25from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
26from typing_extensions import Self
27
28from pyee import EventEmitter
29
30from bumble import core
31from bumble import l2cap
32from bumble import sdp
33from .colors import color
34from .core import (
35    UUID,
36    BT_RFCOMM_PROTOCOL_ID,
37    BT_BR_EDR_TRANSPORT,
38    BT_L2CAP_PROTOCOL_ID,
39    InvalidArgumentError,
40    InvalidStateError,
41    InvalidPacketError,
42    ProtocolError,
43)
44
45if TYPE_CHECKING:
46    from bumble.device import Device, Connection
47
48# -----------------------------------------------------------------------------
49# Logging
50# -----------------------------------------------------------------------------
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
52
53
54# -----------------------------------------------------------------------------
55# Constants
56# -----------------------------------------------------------------------------
57# fmt: off
58
# L2CAP PSM on which RFCOMM runs.
RFCOMM_PSM = 0x0003
# Maximum number of received packets a DLC keeps while no sink is set. It is
# used as a collections.deque maxlen, so once full the oldest packet is dropped.
DEFAULT_RX_QUEUE_SIZE = 32
61
class FrameType(enum.IntEnum):
    """RFCOMM frame types, as control-field values with the P/F bit cleared.

    The P/F bit is bit 5 (mask 0x10). The bit patterns below are listed
    LSB-first (bit 1 .. bit 8), with '_' marking the P/F bit position.
    """
    SABM = 0x2F  # Control field [1,1,1,1,_,1,0,0] LSB-first
    UA   = 0x63  # Control field [1,1,0,0,_,1,1,0] LSB-first
    DM   = 0x0F  # Control field [1,1,1,1,_,0,0,0] LSB-first
    DISC = 0x43  # Control field [1,1,0,0,_,0,1,0] LSB-first
    UIH  = 0xEF  # Control field [1,1,1,1,_,1,1,1] LSB-first
    UI   = 0x03  # Control field [1,1,0,0,_,0,0,0] LSB-first

class MccType(enum.IntEnum):
    """Multiplexer control command types.

    On the wire the type octet is (type << 2) | (C/R << 1) | EA; see
    RFCOMM_Frame.make_mcc / RFCOMM_Frame.parse_mcc.
    """
    PN  = 0x20  # DLC parameter negotiation
    MSC = 0x38  # Modem status command
73
74
# FCS CRC lookup table used by compute_fcs(): a reflected CRC-8 with generator
# x^8 + x^2 + x + 1 (0x07, reflected form 0xE0), indexed by
# (running CRC XOR input octet).
# NOTE(review): should match the FCS table in 3GPP TS 07.10 Annex B; confirm
# against the spec if this table is ever edited.
CRC_TABLE = bytes([
    0X00, 0X91, 0XE3, 0X72, 0X07, 0X96, 0XE4, 0X75,
    0X0E, 0X9F, 0XED, 0X7C, 0X09, 0X98, 0XEA, 0X7B,
    0X1C, 0X8D, 0XFF, 0X6E, 0X1B, 0X8A, 0XF8, 0X69,
    0X12, 0X83, 0XF1, 0X60, 0X15, 0X84, 0XF6, 0X67,
    0X38, 0XA9, 0XDB, 0X4A, 0X3F, 0XAE, 0XDC, 0X4D,
    0X36, 0XA7, 0XD5, 0X44, 0X31, 0XA0, 0XD2, 0X43,
    0X24, 0XB5, 0XC7, 0X56, 0X23, 0XB2, 0XC0, 0X51,
    0X2A, 0XBB, 0XC9, 0X58, 0X2D, 0XBC, 0XCE, 0X5F,
    0X70, 0XE1, 0X93, 0X02, 0X77, 0XE6, 0X94, 0X05,
    0X7E, 0XEF, 0X9D, 0X0C, 0X79, 0XE8, 0X9A, 0X0B,
    0X6C, 0XFD, 0X8F, 0X1E, 0X6B, 0XFA, 0X88, 0X19,
    0X62, 0XF3, 0X81, 0X10, 0X65, 0XF4, 0X86, 0X17,
    0X48, 0XD9, 0XAB, 0X3A, 0X4F, 0XDE, 0XAC, 0X3D,
    0X46, 0XD7, 0XA5, 0X34, 0X41, 0XD0, 0XA2, 0X33,
    0X54, 0XC5, 0XB7, 0X26, 0X53, 0XC2, 0XB0, 0X21,
    0X5A, 0XCB, 0XB9, 0X28, 0X5D, 0XCC, 0XBE, 0X2F,
    0XE0, 0X71, 0X03, 0X92, 0XE7, 0X76, 0X04, 0X95,
    0XEE, 0X7F, 0X0D, 0X9C, 0XE9, 0X78, 0X0A, 0X9B,
    0XFC, 0X6D, 0X1F, 0X8E, 0XFB, 0X6A, 0X18, 0X89,
    0XF2, 0X63, 0X11, 0X80, 0XF5, 0X64, 0X16, 0X87,
    0XD8, 0X49, 0X3B, 0XAA, 0XDF, 0X4E, 0X3C, 0XAD,
    0XD6, 0X47, 0X35, 0XA4, 0XD1, 0X40, 0X32, 0XA3,
    0XC4, 0X55, 0X27, 0XB6, 0XC3, 0X52, 0X20, 0XB1,
    0XCA, 0X5B, 0X29, 0XB8, 0XCD, 0X5C, 0X2E, 0XBF,
    0X90, 0X01, 0X73, 0XE2, 0X97, 0X06, 0X74, 0XE5,
    0X9E, 0X0F, 0X7D, 0XEC, 0X99, 0X08, 0X7A, 0XEB,
    0X8C, 0X1D, 0X6F, 0XFE, 0X8B, 0X1A, 0X68, 0XF9,
    0X82, 0X13, 0X61, 0XF0, 0X85, 0X14, 0X66, 0XF7,
    0XA8, 0X39, 0X4B, 0XDA, 0XAF, 0X3E, 0X4C, 0XDD,
    0XA6, 0X37, 0X45, 0XD4, 0XA1, 0X30, 0X42, 0XD3,
    0XB4, 0X25, 0X57, 0XC6, 0XB3, 0X22, 0X50, 0XC1,
    0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF
])
110
# Default L2CAP MTU for the RFCOMM channel (not referenced in this part of the
# file; presumably used when opening the L2CAP channel — confirm at call site).
RFCOMM_DEFAULT_L2CAP_MTU        = 2048
# Default initial credits. 7 is the largest value that fits the 3-bit PN
# initial-credits field (see RFCOMM_MCC_PN).
RFCOMM_DEFAULT_INITIAL_CREDITS  = 7
# Each DLC's rx_max_credits: when it grants credits, it tops the peer back up
# to this many.
RFCOMM_DEFAULT_MAX_CREDITS      = 32
# A DLC grants more rx credits once its remaining rx credits are at or below
# this value.
RFCOMM_DEFAULT_CREDIT_THRESHOLD = RFCOMM_DEFAULT_MAX_CREDITS // 2
# Default maximum RFCOMM frame size (not referenced in this part of the file).
RFCOMM_DEFAULT_MAX_FRAME_SIZE   = 2000

# Bounds of the dynamically allocated RFCOMM server channel numbers
# (presumably inclusive — verify where they are used).
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END   = 30
119
120# fmt: on
121
122
123# -----------------------------------------------------------------------------
def make_service_sdp_records(
    service_record_handle: int, channel: int, uuid: Optional[UUID] = None
) -> List[sdp.ServiceAttribute]:
    """
    Build the SDP attributes advertising an RFCOMM service.

    The result always holds the service record handle, membership of the
    public browse group, and a protocol descriptor list made of L2CAP followed
    by RFCOMM with the given channel number. A Service Class ID List attribute
    containing ``uuid`` is appended only when a ``uuid`` is provided.
    """
    handle_attribute = sdp.ServiceAttribute(
        sdp.SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
        sdp.DataElement.unsigned_integer_32(service_record_handle),
    )
    browse_group_attribute = sdp.ServiceAttribute(
        sdp.SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
        sdp.DataElement.sequence([sdp.DataElement.uuid(sdp.SDP_PUBLIC_BROWSE_ROOT)]),
    )

    # Protocol stack: [[L2CAP], [RFCOMM, channel]]
    l2cap_protocol = sdp.DataElement.sequence(
        [sdp.DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]
    )
    rfcomm_protocol = sdp.DataElement.sequence(
        [
            sdp.DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
            sdp.DataElement.unsigned_integer_8(channel),
        ]
    )
    protocol_attribute = sdp.ServiceAttribute(
        sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
        sdp.DataElement.sequence([l2cap_protocol, rfcomm_protocol]),
    )

    records = [handle_attribute, browse_group_attribute, protocol_attribute]
    if uuid:
        service_class_attribute = sdp.ServiceAttribute(
            sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
            sdp.DataElement.sequence([sdp.DataElement.uuid(uuid)]),
        )
        records.append(service_class_attribute)
    return records
169
170
171# -----------------------------------------------------------------------------
async def find_rfcomm_channels(connection: Connection) -> Dict[int, List[UUID]]:
    """Searches all RFCOMM channels and their associated UUID from SDP service records.

    Service records returned by the remote device that lack a service class
    list or an RFCOMM channel, or whose protocol descriptor list does not have
    the expected layout, are logged and skipped. They do not abort the whole
    search.

    Args:
        connection: ACL connection to make SDP search.

    Returns:
        Dictionary mapping from channel number to service class UUID list.
    """
    results: Dict[int, List[UUID]] = {}
    async with sdp.Client(connection) as sdp_client:
        search_result = await sdp_client.search_attributes(
            uuids=[core.BT_RFCOMM_PROTOCOL_ID],
            attribute_ids=[
                sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
                sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
            ],
        )
        for attribute_lists in search_result:
            service_classes: List[UUID] = []
            channel: Optional[int] = None
            try:
                for attribute in attribute_lists:
                    if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
                        # Expected layout:
                        # [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]].
                        protocol_descriptor_list = attribute.value.value
                        channel = protocol_descriptor_list[1].value[1].value
                    elif attribute.id == sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID:
                        service_class_id_list = attribute.value.value
                        service_classes = [
                            service_class.value
                            for service_class in service_class_id_list
                        ]
            except (IndexError, TypeError) as error:
                # The record comes from the peer and may not follow the
                # expected layout; skip it instead of failing the search.
                logger.warning(f"Malformed result {attribute_lists}: {error}")
                continue
            # Channel 0 is not a valid RFCOMM server channel, so a falsy
            # channel is rejected as well as a missing one.
            if not service_classes or not channel:
                logger.warning(f"Bad result {attribute_lists}.")
            else:
                results[channel] = service_classes
    return results
208
209
210# -----------------------------------------------------------------------------
async def find_rfcomm_channel_with_uuid(
    connection: Connection, uuid: str | UUID
) -> Optional[int]:
    """Searches an RFCOMM channel associated with given UUID from service records.

    A string ``uuid`` is converted to a UUID first. The first channel (in the
    iteration order of find_rfcomm_channels' result) whose service class list
    contains the UUID is returned.

    Args:
        connection: ACL connection to make SDP search.
        uuid: UUID of service record to search for.

    Returns:
        RFCOMM channel number if found, otherwise None.
    """
    target = UUID(uuid) if isinstance(uuid, str) else uuid
    channels = await find_rfcomm_channels(connection)
    for channel_number, class_ids in channels.items():
        if target in class_ids:
            return channel_number
    return None
235
236
237# -----------------------------------------------------------------------------
def compute_fcs(buffer: bytes) -> int:
    """Return the RFCOMM frame check sequence (FCS) of ``buffer``.

    Runs the table-driven CRC over every octet, starting from 0xFF, and
    returns the one's complement of the final value.
    """
    crc = 0xFF
    for octet in buffer:
        crc = CRC_TABLE[crc ^ octet]
    # crc is always a single octet, so XOR with 0xFF equals 0xFF - crc.
    return crc ^ 0xFF
243
244
245# -----------------------------------------------------------------------------
class RFCOMM_Frame:
    """An RFCOMM frame: address, control, length indicator, information, FCS.

    Frames are built from their fields (``__init__`` or the ``sabm``/``ua``/
    ``dm``/``disc``/``uih`` helpers) or parsed with ``from_bytes``, and are
    serialized with ``bytes(frame)``.
    """

    def __init__(
        self,
        frame_type: FrameType,
        c_r: int,
        dlci: int,
        p_f: int,
        information: bytes = b'',
        with_credits: bool = False,
    ) -> None:
        """
        Args:
            frame_type: Frame type (control field value without the P/F bit).
            c_r: Command/response bit of the address field.
            dlci: Data link connection identifier.
            p_f: Poll/final bit.
            information: Information field. For credit-carrying UIH frames it
                starts with the credit octet.
            with_credits: True when ``information`` starts with a credit
                octet, which is excluded from the length indicator.
        """
        self.type = frame_type
        self.c_r = c_r
        self.dlci = dlci
        self.p_f = p_f
        self.information = information
        length = len(information)
        if with_credits:
            # The credit octet is not counted in the length indicator.
            length -= 1
        if length > 0x7F:
            # 2-byte length indicator (EA bit cleared in the first octet)
            self.length = bytes([(length & 0x7F) << 1, (length >> 7) & 0xFF])
        else:
            # 1-byte length indicator (EA bit set)
            self.length = bytes([(length << 1) | 1])
        self.address = (dlci << 2) | (c_r << 1) | 1
        self.control = frame_type | (p_f << 4)
        # The FCS of UIH frames covers only the address and control octets;
        # for the other frame types it also covers the length indicator.
        if frame_type == FrameType.UIH:
            self.fcs = compute_fcs(bytes([self.address, self.control]))
        else:
            self.fcs = compute_fcs(bytes([self.address, self.control]) + self.length)

    @staticmethod
    def parse_mcc(data) -> Tuple[int, bool, bytes]:
        """Split a multiplexer control message into (type, C/R, value).

        The length field is one or more octets. Each octet carries 7 length
        bits above its EA bit; EA=1 marks the last octet, and the first octet
        holds the least significant bits. The value is sliced to that length.
        """
        mcc_type = data[0] >> 2
        c_r = bool((data[0] >> 1) & 1)
        length = 0
        shift = 0
        offset = 1
        while True:
            octet = data[offset]
            offset += 1
            length |= (octet >> 1) << shift
            shift += 7
            if octet & 1:
                break
        value = data[offset : offset + length]

        return (mcc_type, c_r, value)

    @staticmethod
    def make_mcc(mcc_type: int, c_r: int, data: bytes) -> bytes:
        """Build a multiplexer control message: type octet, length, value.

        The length is encoded as in ``parse_mcc``. A single octet is used for
        values of up to 127 bytes; longer values get extension octets.
        """
        type_octet = (mcc_type << 2 | c_r << 1 | 1) & 0xFF
        remaining = len(data)
        length_octets = bytearray()
        while True:
            low_bits = remaining & 0x7F
            remaining >>= 7
            if remaining:
                # EA=0: more length octets follow
                length_octets.append(low_bits << 1)
            else:
                # EA=1: last length octet
                length_octets.append((low_bits << 1) | 1)
                break
        return bytes([type_octet]) + bytes(length_octets) + data

    @staticmethod
    def sabm(c_r: int, dlci: int):
        """SABM frame (P bit set) for ``dlci``."""
        return RFCOMM_Frame(FrameType.SABM, c_r, dlci, 1)

    @staticmethod
    def ua(c_r: int, dlci: int):
        """UA frame (F bit set) for ``dlci``."""
        return RFCOMM_Frame(FrameType.UA, c_r, dlci, 1)

    @staticmethod
    def dm(c_r: int, dlci: int):
        """DM frame (F bit set) for ``dlci``."""
        return RFCOMM_Frame(FrameType.DM, c_r, dlci, 1)

    @staticmethod
    def disc(c_r: int, dlci: int):
        """DISC frame (P bit set) for ``dlci``."""
        return RFCOMM_Frame(FrameType.DISC, c_r, dlci, 1)

    @staticmethod
    def uih(c_r: int, dlci: int, information: bytes, p_f: int = 0):
        """UIH frame; with ``p_f == 1`` the first information octet is a credit count."""
        return RFCOMM_Frame(
            FrameType.UIH, c_r, dlci, p_f, information, with_credits=(p_f == 1)
        )

    @staticmethod
    def from_bytes(data: bytes) -> RFCOMM_Frame:
        """Parse a serialized frame.

        Raises:
            InvalidPacketError: if the FCS does not match.
            ValueError: if the control field is not a known FrameType.
        """
        # Extract fields
        dlci = (data[0] >> 2) & 0x3F
        c_r = (data[0] >> 1) & 0x01
        frame_type = FrameType(data[1] & 0xEF)
        p_f = (data[1] >> 4) & 0x01
        if data[2] & 0x01:
            # 1-byte length indicator
            information = data[3:-1]
        else:
            # 2-byte length indicator
            information = data[4:-1]
        fcs = data[-1]

        # Construct the frame and check the CRC. A UIH frame with P/F set
        # carries a leading credit octet (mirrors RFCOMM_Frame.uih), so the
        # reconstructed length indicator matches the received one.
        frame = RFCOMM_Frame(
            frame_type,
            c_r,
            dlci,
            p_f,
            information,
            with_credits=(frame_type == FrameType.UIH and p_f == 1),
        )
        if frame.fcs != fcs:
            logger.warning(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}')
            raise InvalidPacketError('fcs mismatch')

        return frame

    def __bytes__(self) -> bytes:
        return (
            bytes([self.address, self.control])
            + self.length
            + self.information
            + bytes([self.fcs])
        )

    def __str__(self) -> str:
        return (
            f'{color(self.type.name, "yellow")}'
            f'(c/r={self.c_r},'
            f'dlci={self.dlci},'
            f'p/f={self.p_f},'
            f'length={len(self.information)},'
            f'fcs=0x{self.fcs:02X})'
        )
361
362
363# -----------------------------------------------------------------------------
@dataclasses.dataclass
class RFCOMM_MCC_PN:
    """Value of an MCC PN (parameter negotiation) message.

    Wire format, 8 octets: DLCI, CL, priority, ack timer, max frame size
    (2 octets, little-endian), max retransmissions, initial credits (3 bits).
    """

    dlci: int
    cl: int
    priority: int
    ack_timer: int
    max_frame_size: int
    max_retransmissions: int
    initial_credits: int

    def __post_init__(self) -> None:
        # The wire field is 3 bits wide; values outside [1, 7] are only warned about.
        if not 1 <= self.initial_credits <= 7:
            logger.warning(
                f'Initial credits {self.initial_credits} is out of range [1, 7].'
            )

    @staticmethod
    def from_bytes(data: bytes) -> RFCOMM_MCC_PN:
        """Decode the 8-octet PN value; only the low 3 bits of the credits octet are kept."""
        dlci, cl, priority, ack_timer = data[0], data[1], data[2], data[3]
        max_frame_size = int.from_bytes(data[4:6], byteorder='little')
        max_retransmissions = data[6]
        initial_credits = data[7] & 0x07
        return RFCOMM_MCC_PN(
            dlci,
            cl,
            priority,
            ack_timer,
            max_frame_size,
            max_retransmissions,
            initial_credits,
        )

    def __bytes__(self) -> bytes:
        head = bytes(
            field & 0xFF
            for field in (self.dlci, self.cl, self.priority, self.ack_timer)
        )
        frame_size = (self.max_frame_size & 0xFFFF).to_bytes(2, byteorder='little')
        # Only 3 bits are meaningful.
        tail = bytes((self.max_retransmissions & 0xFF, self.initial_credits & 0x07))
        return head + frame_size + tail
406
407
408# -----------------------------------------------------------------------------
@dataclasses.dataclass
class RFCOMM_MCC_MSC:
    """Value of an MCC MSC (modem status) message.

    Wire format: a DLCI octet ((dlci << 2) | 0b11), then a signals octet with
    EA in bit 0, fc in bit 1, rtc in bit 2, rtr in bit 3, ic in bit 6 and
    dv in bit 7.
    """

    dlci: int
    fc: int
    rtc: int
    rtr: int
    ic: int
    dv: int

    @staticmethod
    def from_bytes(data: bytes) -> RFCOMM_MCC_MSC:
        """Decode the DLCI octet and the signals octet."""
        address, signals = data[0], data[1]

        def bit(position: int) -> int:
            return (signals >> position) & 1

        return RFCOMM_MCC_MSC(
            dlci=address >> 2,
            fc=bit(1),
            rtc=bit(2),
            rtr=bit(3),
            ic=bit(6),
            dv=bit(7),
        )

    def __bytes__(self) -> bytes:
        address = (self.dlci << 2) | 3
        signals = 1  # EA bit
        for position, value in (
            (1, self.fc),
            (2, self.rtc),
            (3, self.rtr),
            (6, self.ic),
            (7, self.dv),
        ):
            signals |= value << position
        return bytes((address, signals))
441
442
443# -----------------------------------------------------------------------------
class DLC(EventEmitter):
    """One RFCOMM data link connection (DLC), carried over a Multiplexer.

    Outgoing data is buffered in ``tx_buffer`` and sent in UIH frames while
    the peer has granted us credits (``tx_credits``). Credits we grant the
    peer (``rx_credits``) are topped up by piggybacking a credit octet on an
    outgoing UIH frame. Received data goes to ``sink``, or is queued (up to
    DEFAULT_RX_QUEUE_SIZE packets) until a sink is set.

    Emits ``'open'`` when an incoming connection completes (SABM received)
    and ``'close'`` when the DLC is disconnected or aborted.
    """

    class State(enum.IntEnum):
        INIT = 0x00
        CONNECTING = 0x01
        CONNECTED = 0x02
        DISCONNECTING = 0x03
        DISCONNECTED = 0x04
        RESET = 0x05

    def __init__(
        self,
        multiplexer: Multiplexer,
        dlci: int,
        tx_max_frame_size: int,
        tx_initial_credits: int,
        rx_max_frame_size: int,
        rx_initial_credits: int,
    ) -> None:
        """
        Args:
            multiplexer: Multiplexer carrying this DLC.
            dlci: Data link connection identifier.
            tx_max_frame_size: Largest frame size we may send.
            tx_initial_credits: Credits the peer initially granted us.
            rx_max_frame_size: Largest frame size we accept.
            rx_initial_credits: Credits we initially grant the peer.
        """
        super().__init__()
        self.multiplexer = multiplexer
        self.dlci = dlci
        self.rx_max_frame_size = rx_max_frame_size
        self.rx_initial_credits = rx_initial_credits
        self.rx_max_credits = RFCOMM_DEFAULT_MAX_CREDITS
        self.rx_credits = rx_initial_credits
        self.rx_credits_threshold = RFCOMM_DEFAULT_CREDIT_THRESHOLD
        self.tx_max_frame_size = tx_max_frame_size
        self.tx_credits = tx_initial_credits
        self.tx_buffer = b''
        self.state = DLC.State.INIT
        self.role = multiplexer.role
        # C/R bit used for commands we send: 1 on the initiator side.
        self.c_r = 1 if self.role == Multiplexer.Role.INITIATOR else 0
        self.connection_result: Optional[asyncio.Future] = None
        self.disconnection_result: Optional[asyncio.Future] = None
        self.drained = asyncio.Event()
        self.drained.set()
        # Queued packets when sink is not set.
        self._enqueued_rx_packets: collections.deque[bytes] = collections.deque(
            maxlen=DEFAULT_RX_QUEUE_SIZE
        )
        self._sink: Optional[Callable[[bytes], None]] = None

        # Compute the MTU: the largest information field we send, bounded by
        # the max frame size and by what fits in one L2CAP PDU.
        max_overhead = 4 + 1  # header with 2-byte length + fcs
        self.mtu = min(
            tx_max_frame_size, self.multiplexer.l2cap_channel.peer_mtu - max_overhead
        )

    @property
    def sink(self) -> Optional[Callable[[bytes], None]]:
        """Callback receiving incoming data."""
        return self._sink

    @sink.setter
    def sink(self, sink: Optional[Callable[[bytes], None]]) -> None:
        self._sink = sink
        # Dump queued packets to sink
        if sink:
            for packet in self._enqueued_rx_packets:
                sink(packet)  # pylint: disable=not-callable
            self._enqueued_rx_packets.clear()

    @staticmethod
    def _complete(future: Optional[asyncio.Future]) -> None:
        """Resolve ``future`` with None unless it is None or already done (e.g. cancelled)."""
        if future is not None and not future.done():
            future.set_result(None)

    def change_state(self, new_state: State) -> None:
        logger.debug(f'{self} state change -> {color(new_state.name, "magenta")}')
        self.state = new_state

    def send_frame(self, frame: RFCOMM_Frame) -> None:
        self.multiplexer.send_frame(frame)

    def on_frame(self, frame: RFCOMM_Frame) -> None:
        """Dispatch ``frame`` to the matching ``on_<type>_frame`` handler."""
        handler = getattr(self, f'on_{frame.type.name}_frame'.lower())
        handler(frame)

    def _send_msc_command(self) -> None:
        """Send our modem status (MCC MSC command) for this DLC on DLCI 0."""
        msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
        mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=1, data=bytes(msc))
        logger.debug(f'>>> MCC MSC Command: {msc}')
        self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))

    def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:
        """Accept an incoming connection for this DLC (responder side)."""
        if self.state != DLC.State.CONNECTING:
            logger.warning(
                color('!!! received SABM when not in CONNECTING state', 'red')
            )
            return

        self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci))

        # Exchange the modem status with the peer
        self._send_msc_command()

        self.change_state(DLC.State.CONNECTED)
        self.emit('open')

    def on_ua_frame(self, _frame: RFCOMM_Frame) -> None:
        """Complete a pending connect (after our SABM) or disconnect (after our DISC)."""
        if self.state == DLC.State.CONNECTING:
            # Exchange the modem status with the peer
            self._send_msc_command()

            self.change_state(DLC.State.CONNECTED)
            self._complete(self.connection_result)
            self.connection_result = None
            self.multiplexer.on_dlc_open_complete(self)
        elif self.state == DLC.State.DISCONNECTING:
            self.change_state(DLC.State.DISCONNECTED)
            self._complete(self.disconnection_result)
            self.disconnection_result = None
            self.multiplexer.on_dlc_disconnection(self)
            self.emit('close')
        else:
            logger.warning(
                color(
                    (
                        '!!! received UA frame when not in '
                        'CONNECTING or DISCONNECTING state'
                    ),
                    'red',
                )
            )

    def on_dm_frame(self, frame: RFCOMM_Frame) -> None:
        # Multiplexer.on_pdu handles DM frames itself, so this is not reached
        # from there.
        # TODO: handle all states
        pass

    def on_disc_frame(self, _frame: RFCOMM_Frame) -> None:
        """Acknowledge a peer DISC; if connected (or disconnecting), close the DLC."""
        self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci))
        if self.state not in (DLC.State.CONNECTED, DLC.State.DISCONNECTING):
            # TODO: handle the remaining states
            return

        self.change_state(DLC.State.DISCONNECTED)
        # If we had a DISC of our own in flight, consider it complete.
        self._complete(self.disconnection_result)
        self.disconnection_result = None
        self.multiplexer.on_dlc_disconnection(self)
        self.emit('close')

    def on_uih_frame(self, frame: RFCOMM_Frame) -> None:
        """Handle a UIH frame carrying credits and/or user data.

        With P/F set, the first information octet is a credit grant from the
        peer. Each non-empty data frame consumes one rx credit.
        """
        data = frame.information
        if frame.p_f == 1:
            # With credits
            if data:
                received_credits = data[0]
                data = data[1:]
            else:
                # Malformed: P/F set but no credit octet.
                logger.warning(
                    color('!!! received credit frame without credit field', 'red')
                )
                received_credits = 0
            self.tx_credits += received_credits

            logger.debug(
                f'<<< Credits [{self.dlci}]: '
                f'received {received_credits}, total={self.tx_credits}'
            )

        logger.debug(
            f'{color("<<< Data", "yellow")} '
            f'[{self.dlci}] {len(data)} bytes, '
            f'rx_credits={self.rx_credits}: {data.hex()}'
        )
        if data:
            if self._sink:
                self._sink(data)  # pylint: disable=not-callable
            else:
                queue = self._enqueued_rx_packets
                # Appending to a full bounded deque drops the oldest packet.
                if queue.maxlen and len(queue) >= queue.maxlen:
                    logger.warning(f'DLC [{self.dlci}] received packet queue is full')
                queue.append(data)

            # Update the credits
            if self.rx_credits > 0:
                self.rx_credits -= 1
            else:
                logger.warning(color('!!! received frame with no rx credits', 'red'))

        # Check if there's anything to send (including credits)
        self.process_tx()

    def on_ui_frame(self, frame: RFCOMM_Frame) -> None:
        pass

    def on_mcc_msc(self, c_r: bool, msc: RFCOMM_MCC_MSC) -> None:
        """Answer an MSC command with our own modem status; log MSC responses."""
        if c_r:
            # Command
            logger.debug(f'<<< MCC MSC Command: {msc}')
            response = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
            mcc = RFCOMM_Frame.make_mcc(
                mcc_type=MccType.MSC, c_r=0, data=bytes(response)
            )
            logger.debug(f'>>> MCC MSC Response: {response}')
            self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
        else:
            # Response
            logger.debug(f'<<< MCC MSC Response: {msc}')

    def connect(self) -> None:
        """Send SABM for this DLC; ``connection_result`` resolves when UA arrives.

        Raises:
            InvalidStateError: if the DLC is not in the INIT state.
        """
        if self.state != DLC.State.INIT:
            raise InvalidStateError('invalid state')

        self.change_state(DLC.State.CONNECTING)
        self.connection_result = asyncio.get_running_loop().create_future()
        self.send_frame(RFCOMM_Frame.sabm(c_r=self.c_r, dlci=self.dlci))

    async def disconnect(self) -> None:
        """Send DISC for this DLC and wait until the disconnection completes.

        Raises:
            InvalidStateError: if the DLC is not CONNECTED.
        """
        if self.state != DLC.State.CONNECTED:
            raise InvalidStateError('invalid state')

        self.disconnection_result = asyncio.get_running_loop().create_future()
        self.change_state(DLC.State.DISCONNECTING)
        # self.c_r is 1 for the initiator and 0 otherwise (see __init__).
        self.send_frame(RFCOMM_Frame.disc(c_r=self.c_r, dlci=self.dlci))
        await self.disconnection_result

    def accept(self) -> None:
        """Answer the peer's PN command with our parameters and await its SABM.

        Raises:
            InvalidStateError: if the DLC is not in the INIT state.
        """
        if self.state != DLC.State.INIT:
            raise InvalidStateError('invalid state')

        pn = RFCOMM_MCC_PN(
            dlci=self.dlci,
            cl=0xE0,
            priority=7,
            ack_timer=0,
            max_frame_size=self.rx_max_frame_size,
            max_retransmissions=0,
            initial_credits=self.rx_initial_credits,
        )
        mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.PN, c_r=0, data=bytes(pn))
        logger.debug(f'>>> PN Response: {pn}')
        self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
        self.change_state(DLC.State.CONNECTING)

    def rx_credits_needed(self) -> int:
        """Credits to grant the peer now: a top-up to rx_max_credits once at/below the threshold."""
        if self.rx_credits <= self.rx_credits_threshold:
            return self.rx_max_credits - self.rx_credits

        return 0

    def process_tx(self) -> None:
        """Send buffered data and/or pending rx credits as UIH frames.

        Data is only sent while we hold tx credits; each data-carrying frame
        spends one. Pending rx credits go in the first frame, alone if no data
        can be sent. ``drained`` is set whenever the tx buffer ends up empty.
        """
        rx_credits_needed = self.rx_credits_needed()
        while (self.tx_buffer and self.tx_credits > 0) or rx_credits_needed > 0:
            # Get the next chunk, up to MTU size
            if rx_credits_needed > 0:
                # The first information octet is the credit grant; attach data
                # only if we actually hold a tx credit for it.
                payload = b''
                if self.tx_buffer and self.tx_credits > 0:
                    payload = self.tx_buffer[: self.mtu - 1]
                    self.tx_buffer = self.tx_buffer[len(payload) :]
                chunk = bytes([rx_credits_needed]) + payload
                self.rx_credits += rx_credits_needed
                tx_credit_spent = bool(payload)
            else:
                chunk = self.tx_buffer[: self.mtu]
                self.tx_buffer = self.tx_buffer[len(chunk) :]
                tx_credit_spent = True

            # Update the tx credits
            # (no tx credit spent for empty frames that only contain rx credits)
            if tx_credit_spent:
                self.tx_credits -= 1

            # Send the frame
            logger.debug(
                f'>>> sending {len(chunk)} bytes with {rx_credits_needed} credits, '
                f'rx_credits={self.rx_credits}, '
                f'tx_credits={self.tx_credits}'
            )
            self.send_frame(
                RFCOMM_Frame.uih(
                    c_r=self.c_r,
                    dlci=self.dlci,
                    information=chunk,
                    p_f=1 if rx_credits_needed > 0 else 0,
                )
            )

            rx_credits_needed = 0

        # Checked after the loop so that an empty write (or a call with
        # nothing to send) still releases drain() waiters.
        if not self.tx_buffer:
            self.drained.set()

    # Stream protocol
    def write(self, data: Union[bytes, str]) -> None:
        """Queue ``data`` (str is UTF-8 encoded) and send what credits allow.

        Raises:
            InvalidArgumentError: if ``data`` is neither bytes nor str.
        """
        # We can only send bytes
        if not isinstance(data, bytes):
            if isinstance(data, str):
                # Automatically convert strings to bytes using UTF-8
                data = data.encode('utf-8')
            else:
                raise InvalidArgumentError('write only accept bytes or strings')

        self.tx_buffer += data
        self.drained.clear()
        self.process_tx()

    async def drain(self) -> None:
        """Wait until the tx buffer has been fully handed to the multiplexer."""
        await self.drained.wait()

    def abort(self) -> None:
        """Cancel pending connect/disconnect futures, move to RESET, emit 'close'."""
        logger.debug(f'aborting DLC: {self}')
        if self.connection_result:
            self.connection_result.cancel()
            self.connection_result = None
        if self.disconnection_result:
            self.disconnection_result.cancel()
            self.disconnection_result = None
        self.change_state(DLC.State.RESET)
        self.emit('close')

    def __str__(self) -> str:
        return (
            f'DLC(dlci={self.dlci}, '
            f'state={self.state.name}, '
            f'rx_max_frame_size={self.rx_max_frame_size}, '
            f'rx_credits={self.rx_credits}, '
            f'rx_max_credits={self.rx_max_credits}, '
            f'tx_max_frame_size={self.tx_max_frame_size}, '
            f'tx_credits={self.tx_credits}'
            ')'
        )
749
750
751# -----------------------------------------------------------------------------
752class Multiplexer(EventEmitter):
    # Local role in the RFCOMM session. DLCs derive their C/R bit from it
    # (1 for INITIATOR, 0 otherwise).
    class Role(enum.IntEnum):
        INITIATOR = 0x00
        RESPONDER = 0x01

    # State of the multiplexer session (control channel, DLCI 0).
    class State(enum.IntEnum):
        INIT = 0x00
        CONNECTING = 0x01
        CONNECTED = 0x02
        OPENING = 0x03
        DISCONNECTING = 0x04
        DISCONNECTED = 0x05
        RESET = 0x06

    # Futures of in-flight operations; None when nothing is pending.
    connection_result: Optional[asyncio.Future]
    disconnection_result: Optional[asyncio.Future]
    open_result: Optional[asyncio.Future]
    # Called with the server channel number (DLCI >> 1) of an incoming PN
    # command; a falsy return presumably rejects the channel — confirm in
    # on_mcc_pn.
    acceptor: Optional[Callable[[int], Optional[Tuple[int, int]]]]
    # Registered DLCs, keyed by DLCI.
    dlcs: Dict[int, DLC]
771
772    def __init__(self, l2cap_channel: l2cap.ClassicChannel, role: Role) -> None:
773        super().__init__()
774        self.role = role
775        self.l2cap_channel = l2cap_channel
776        self.state = Multiplexer.State.INIT
777        self.dlcs = {}  # DLCs, by DLCI
778        self.connection_result = None
779        self.disconnection_result = None
780        self.open_result = None
781        self.open_pn: Optional[RFCOMM_MCC_PN] = None
782        self.open_rx_max_credits = 0
783        self.acceptor = None
784
785        # Become a sink for the L2CAP channel
786        l2cap_channel.sink = self.on_pdu
787
788        l2cap_channel.on('close', self.on_l2cap_channel_close)
789
790    def change_state(self, new_state: State) -> None:
791        logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
792        self.state = new_state
793
794    def send_frame(self, frame: RFCOMM_Frame) -> None:
795        logger.debug(f'>>> Multiplexer sending {frame}')
796        self.l2cap_channel.send_pdu(frame)
797
798    def on_pdu(self, pdu: bytes) -> None:
799        frame = RFCOMM_Frame.from_bytes(pdu)
800        logger.debug(f'<<< Multiplexer received {frame}')
801
802        # Dispatch to this multiplexer or to a dlc, depending on the address
803        if frame.dlci == 0:
804            self.on_frame(frame)
805        else:
806            if frame.type == FrameType.DM:
807                # DM responses are for a DLCI, but since we only create the dlc when we
808                # receive a PN response (because we need the parameters), we handle DM
809                # frames at the Multiplexer level
810                self.on_dm_frame(frame)
811            else:
812                dlc = self.dlcs.get(frame.dlci)
813                if dlc is None:
814                    logger.warning(f'no dlc for DLCI {frame.dlci}')
815                    return
816                dlc.on_frame(frame)
817
818    def on_frame(self, frame: RFCOMM_Frame) -> None:
819        handler = getattr(self, f'on_{frame.type.name}_frame'.lower())
820        handler(frame)
821
822    def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:
823        if self.state != Multiplexer.State.INIT:
824            logger.debug('not in INIT state, ignoring SABM')
825            return
826        self.change_state(Multiplexer.State.CONNECTED)
827        self.send_frame(RFCOMM_Frame.ua(c_r=1, dlci=0))
828
829    def on_ua_frame(self, _frame: RFCOMM_Frame) -> None:
830        if self.state == Multiplexer.State.CONNECTING:
831            self.change_state(Multiplexer.State.CONNECTED)
832            if self.connection_result:
833                self.connection_result.set_result(0)
834                self.connection_result = None
835        elif self.state == Multiplexer.State.DISCONNECTING:
836            self.change_state(Multiplexer.State.DISCONNECTED)
837            if self.disconnection_result:
838                self.disconnection_result.set_result(None)
839                self.disconnection_result = None
840
841    def on_dm_frame(self, _frame: RFCOMM_Frame) -> None:
842        if self.state == Multiplexer.State.OPENING:
843            self.change_state(Multiplexer.State.CONNECTED)
844            if self.open_result:
845                self.open_result.set_exception(
846                    core.ConnectionError(
847                        core.ConnectionError.CONNECTION_REFUSED,
848                        BT_BR_EDR_TRANSPORT,
849                        self.l2cap_channel.connection.peer_address,
850                        'rfcomm',
851                    )
852                )
853                self.open_result = None
854        else:
855            logger.warning(f'unexpected state for DM: {self}')
856
857    def on_disc_frame(self, _frame: RFCOMM_Frame) -> None:
858        self.change_state(Multiplexer.State.DISCONNECTED)
859        self.send_frame(
860            RFCOMM_Frame.ua(
861                c_r=0 if self.role == Multiplexer.Role.INITIATOR else 1, dlci=0
862            )
863        )
864
865    def on_uih_frame(self, frame: RFCOMM_Frame) -> None:
866        (mcc_type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)
867
868        if mcc_type == MccType.PN:
869            pn = RFCOMM_MCC_PN.from_bytes(value)
870            self.on_mcc_pn(c_r, pn)
871        elif mcc_type == MccType.MSC:
872            mcs = RFCOMM_MCC_MSC.from_bytes(value)
873            self.on_mcc_msc(c_r, mcs)
874
875    def on_ui_frame(self, frame: RFCOMM_Frame) -> None:
876        pass
877
878    def on_mcc_pn(self, c_r: bool, pn: RFCOMM_MCC_PN) -> None:
879        if c_r:
880            # Command
881            logger.debug(f'<<< PN Command: {pn}')
882
883            # Check with the multiplexer if there's an acceptor for this channel
884            if pn.dlci & 1:
885                # Not expected, this is an initiator-side number
886                # TODO: error out
887                logger.warning(f'invalid DLCI: {pn.dlci}')
888            else:
889                if self.acceptor:
890                    channel_number = pn.dlci >> 1
891                    if dlc_params := self.acceptor(channel_number):
892                        # Create a new DLC
893                        dlc = DLC(
894                            self,
895                            dlci=pn.dlci,
896                            tx_max_frame_size=pn.max_frame_size,
897                            tx_initial_credits=pn.initial_credits,
898                            rx_max_frame_size=dlc_params[0],
899                            rx_initial_credits=dlc_params[1],
900                        )
901                        self.dlcs[pn.dlci] = dlc
902
903                        # Re-emit the handshake completion event
904                        dlc.on('open', lambda: self.emit('dlc', dlc))
905
906                        # Respond to complete the handshake
907                        dlc.accept()
908                    else:
909                        # No acceptor, we're in Disconnected Mode
910                        self.send_frame(RFCOMM_Frame.dm(c_r=1, dlci=pn.dlci))
911                else:
912                    # No acceptor?? shouldn't happen
913                    logger.warning(color('!!! no acceptor registered', 'red'))
914        else:
915            # Response
916            logger.debug(f'>>> PN Response: {pn}')
917            if self.state == Multiplexer.State.OPENING:
918                assert self.open_pn
919                dlc = DLC(
920                    self,
921                    dlci=pn.dlci,
922                    tx_max_frame_size=pn.max_frame_size,
923                    tx_initial_credits=pn.initial_credits,
924                    rx_max_frame_size=self.open_pn.max_frame_size,
925                    rx_initial_credits=self.open_pn.initial_credits,
926                )
927                self.dlcs[pn.dlci] = dlc
928                self.open_pn = None
929                dlc.connect()
930            else:
931                logger.warning('ignoring PN response')
932
933    def on_mcc_msc(self, c_r: bool, msc: RFCOMM_MCC_MSC) -> None:
934        dlc = self.dlcs.get(msc.dlci)
935        if dlc is None:
936            logger.warning(f'no dlc for DLCI {msc.dlci}')
937            return
938        dlc.on_mcc_msc(c_r, msc)
939
940    async def connect(self) -> None:
941        if self.state != Multiplexer.State.INIT:
942            raise InvalidStateError('invalid state')
943
944        self.change_state(Multiplexer.State.CONNECTING)
945        self.connection_result = asyncio.get_running_loop().create_future()
946        self.send_frame(RFCOMM_Frame.sabm(c_r=1, dlci=0))
947        return await self.connection_result
948
949    async def disconnect(self) -> None:
950        if self.state != Multiplexer.State.CONNECTED:
951            return
952
953        self.disconnection_result = asyncio.get_running_loop().create_future()
954        self.change_state(Multiplexer.State.DISCONNECTING)
955        self.send_frame(
956            RFCOMM_Frame.disc(
957                c_r=1 if self.role == Multiplexer.Role.INITIATOR else 0, dlci=0
958            )
959        )
960        await self.disconnection_result
961
962    async def open_dlc(
963        self,
964        channel: int,
965        max_frame_size: int = RFCOMM_DEFAULT_MAX_FRAME_SIZE,
966        initial_credits: int = RFCOMM_DEFAULT_INITIAL_CREDITS,
967    ) -> DLC:
968        if self.state != Multiplexer.State.CONNECTED:
969            if self.state == Multiplexer.State.OPENING:
970                raise InvalidStateError('open already in progress')
971
972            raise InvalidStateError('not connected')
973
974        self.open_pn = RFCOMM_MCC_PN(
975            dlci=channel << 1,
976            cl=0xF0,
977            priority=7,
978            ack_timer=0,
979            max_frame_size=max_frame_size,
980            max_retransmissions=0,
981            initial_credits=initial_credits,
982        )
983        mcc = RFCOMM_Frame.make_mcc(
984            mcc_type=MccType.PN, c_r=1, data=bytes(self.open_pn)
985        )
986        logger.debug(f'>>> Sending MCC: {self.open_pn}')
987        self.open_result = asyncio.get_running_loop().create_future()
988        self.change_state(Multiplexer.State.OPENING)
989        self.send_frame(
990            RFCOMM_Frame.uih(
991                c_r=1 if self.role == Multiplexer.Role.INITIATOR else 0,
992                dlci=0,
993                information=mcc,
994            )
995        )
996        return await self.open_result
997
998    def on_dlc_open_complete(self, dlc: DLC) -> None:
999        logger.debug(f'DLC [{dlc.dlci}] open complete')
1000
1001        self.change_state(Multiplexer.State.CONNECTED)
1002
1003        if self.open_result:
1004            self.open_result.set_result(dlc)
1005            self.open_result = None
1006
1007    def on_dlc_disconnection(self, dlc: DLC) -> None:
1008        logger.debug(f'DLC [{dlc.dlci}] disconnection')
1009        self.dlcs.pop(dlc.dlci, None)
1010
1011    def on_l2cap_channel_close(self) -> None:
1012        logger.debug('L2CAP channel closed, cleaning up')
1013        if self.open_result:
1014            self.open_result.cancel()
1015            self.open_result = None
1016        if self.disconnection_result:
1017            self.disconnection_result.cancel()
1018            self.disconnection_result = None
1019        for dlc in self.dlcs.values():
1020            dlc.abort()
1021
1022    def __str__(self) -> str:
1023        return f'Multiplexer(state={self.state.name})'
1024
1025
1026# -----------------------------------------------------------------------------
class Client:
    """RFCOMM client: opens an RFCOMM session to a peer over an ACL connection.

    Usable as an async context manager. Entering yields the connected
    :class:`Multiplexer`; exiting shuts the session down.
    """

    # Session multiplexer, set by start() and cleared by shutdown().
    multiplexer: Optional[Multiplexer]
    # L2CAP channel on the RFCOMM PSM, set by start() and cleared by shutdown().
    l2cap_channel: Optional[l2cap.ClassicChannel]

    def __init__(
        self, connection: Connection, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU
    ) -> None:
        self.connection = connection
        self.l2cap_mtu = l2cap_mtu
        self.multiplexer = None
        self.l2cap_channel = None

    async def start(self) -> Multiplexer:
        """Open the L2CAP channel and connect an initiator-side multiplexer.

        Returns:
            The connected :class:`Multiplexer`.

        Raises:
            ProtocolError: if the L2CAP channel could not be created. The
                failure is logged before being re-raised.
        """
        try:
            spec = l2cap.ClassicChannelSpec(psm=RFCOMM_PSM, mtu=self.l2cap_mtu)
            self.l2cap_channel = await self.connection.create_l2cap_channel(
                spec=spec
            )
        except ProtocolError as error:
            logger.warning(f'L2CAP connection failed: {error}')
            raise

        assert self.l2cap_channel is not None
        self.multiplexer = Multiplexer(self.l2cap_channel, Multiplexer.Role.INITIATOR)
        await self.multiplexer.connect()
        return self.multiplexer

    async def shutdown(self) -> None:
        """Disconnect the multiplexer, then close the L2CAP channel.

        Does nothing if there is no multiplexer (``start()`` was never called,
        or the client was already shut down).
        """
        multiplexer = self.multiplexer
        if multiplexer is None:
            return
        await multiplexer.disconnect()
        self.multiplexer = None

        channel = self.l2cap_channel
        if channel:
            await channel.disconnect()
            self.l2cap_channel = None

    async def __aenter__(self) -> Multiplexer:
        return await self.start()

    async def __aexit__(self, *args) -> None:
        await self.shutdown()
1075
1076
1077# -----------------------------------------------------------------------------
class Server(EventEmitter):
    """RFCOMM server: accepts incoming sessions and hands opened DLCs to acceptors.

    Emits ``'start'`` with a new responder :class:`Multiplexer` each time an
    incoming L2CAP channel on the RFCOMM PSM opens. Usable as a context
    manager; leaving the ``with`` block closes the L2CAP server.
    """

    def __init__(
        self, device: Device, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU
    ) -> None:
        super().__init__()
        self.device = device
        # Channel number -> callback receiving each DLC opened on that channel.
        self.acceptors: Dict[int, Callable[[DLC], None]] = {}
        # Channel number -> (rx max_frame_size, rx initial_credits) for new DLCs.
        self.dlc_configs: Dict[int, Tuple[int, int]] = {}

        # Register with the device's L2CAP layer for the RFCOMM PSM
        spec = l2cap.ClassicChannelSpec(psm=RFCOMM_PSM, mtu=l2cap_mtu)
        self.l2cap_server = device.create_l2cap_server(
            spec=spec, handler=self.on_connection
        )

    def listen(
        self,
        acceptor: Callable[[DLC], None],
        channel: int = 0,
        max_frame_size: int = RFCOMM_DEFAULT_MAX_FRAME_SIZE,
        initial_credits: int = RFCOMM_DEFAULT_INITIAL_CREDITS,
    ) -> int:
        """Register ``acceptor`` for DLCs opened on an RFCOMM server channel.

        Args:
            acceptor: called with each :class:`DLC` that finishes opening on
                the channel.
            channel: channel number to listen on, or 0 to use the lowest free
                number in the dynamic channel range.
            max_frame_size: receive maximum frame size used for new DLCs.
            initial_credits: receive initial credits used for new DLCs.

        Returns:
            The channel number now being listened on, or 0 when the requested
            channel is already taken or no dynamic channel is free.
        """
        if channel:
            if channel in self.acceptors:
                return 0  # Busy
        else:
            free_channels = (
                candidate
                for candidate in range(
                    RFCOMM_DYNAMIC_CHANNEL_NUMBER_START,
                    RFCOMM_DYNAMIC_CHANNEL_NUMBER_END + 1,
                )
                if candidate not in self.acceptors
            )
            channel = next(free_channels, channel)
            if channel == 0:
                return 0  # Every dynamic channel is in use

        self.acceptors[channel] = acceptor
        self.dlc_configs[channel] = (max_frame_size, initial_credits)
        return channel

    def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
        """Wait for a new incoming L2CAP channel to open before using it."""
        logger.debug(f'+++ new L2CAP connection: {l2cap_channel}')

        def on_open() -> None:
            self.on_l2cap_channel_open(l2cap_channel)

        l2cap_channel.on('open', on_open)

    def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
        """Attach a responder multiplexer to an opened L2CAP channel and announce it."""
        logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')

        multiplexer = Multiplexer(l2cap_channel, Multiplexer.Role.RESPONDER)
        multiplexer.acceptor = self.accept_dlc
        multiplexer.on('dlc', self.on_dlc)

        self.emit('start', multiplexer)

    def accept_dlc(self, channel_number: int) -> Optional[Tuple[int, int]]:
        """Return the DLC receive parameters for a channel, or None if not listening."""
        return self.dlc_configs.get(channel_number, None)

    def on_dlc(self, dlc: DLC) -> None:
        """Hand a newly opened DLC to the acceptor registered for its channel."""
        logger.debug(f'@@@ new DLC connected: {dlc}')

        channel_number = dlc.dlci >> 1
        acceptor = self.acceptors.get(channel_number)
        if acceptor:
            acceptor(dlc)

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *args) -> None:
        self.l2cap_server.close()
1153