1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19from dataclasses import dataclass
20import logging
21import enum
22import struct
23
24from abc import ABC, abstractmethod
25from pyee import EventEmitter
26from typing import Optional, Callable
27from typing_extensions import override
28
29from bumble import l2cap, device
30from bumble.core import InvalidStateError, ProtocolError
31from bumble.hci import Address
32
33
34# -----------------------------------------------------------------------------
35# Logging
36# -----------------------------------------------------------------------------
# Module logger; the HID classes below trace PDUs and events at DEBUG level.
logger = logging.getLogger(__name__)
38
39
40# -----------------------------------------------------------------------------
41# Constants
42# -----------------------------------------------------------------------------
# fmt: on
# L2CAP PSMs of the HID control and interrupt channels. HID registers L2CAP
# servers on both and also uses them when opening the channels itself.
HID_CONTROL_PSM = 0x0011
HID_INTERRUPT_PSM = 0x0013
46
47
class Message:
    """Base class for HIDP messages.

    Concrete message classes set ``message_type`` and build their wire form in
    ``__bytes__``, starting from the one-byte header returned by ``header()``.
    The nested enums hold the HIDP constants used throughout this module.
    """

    # Report types
    class ReportType(enum.IntEnum):
        OTHER_REPORT = 0x00
        INPUT_REPORT = 0x01
        OUTPUT_REPORT = 0x02
        FEATURE_REPORT = 0x03

    # Handshake parameters
    class Handshake(enum.IntEnum):
        SUCCESSFUL = 0x00
        NOT_READY = 0x01
        ERR_INVALID_REPORT_ID = 0x02
        ERR_UNSUPPORTED_REQUEST = 0x03
        ERR_INVALID_PARAMETER = 0x04
        ERR_UNKNOWN = 0x0E
        ERR_FATAL = 0x0F

    # Message Type
    class MessageType(enum.IntEnum):
        HANDSHAKE = 0x00
        CONTROL = 0x01
        GET_REPORT = 0x04
        SET_REPORT = 0x05
        GET_PROTOCOL = 0x06
        SET_PROTOCOL = 0x07
        DATA = 0x0A

    # Protocol modes
    class ProtocolMode(enum.IntEnum):
        BOOT_PROTOCOL = 0x00
        REPORT_PROTOCOL = 0x01

    # Control Operations
    class ControlCommand(enum.IntEnum):
        SUSPEND = 0x03
        EXIT_SUSPEND = 0x04
        VIRTUAL_CABLE_UNPLUG = 0x05

    # Set by each concrete message class. Declared after MessageType so the
    # annotation resolves even without postponed evaluation of annotations.
    message_type: MessageType

    # Class Method to derive header
    @classmethod
    def header(cls, lower_bits: int = 0x00) -> bytes:
        """Return the one-byte HIDP header for this message class.

        The message type occupies the high nibble. *lower_bits* is the
        message-specific parameter (report type, protocol mode, result code,
        control operation, ...) and occupies the low nibble.

        Raises:
            ValueError: if *lower_bits* does not fit in four bits. It would
                otherwise spill into, and silently corrupt, the message type.
        """
        if not 0 <= lower_bits <= 0x0F:
            raise ValueError(f'lower_bits out of range: {lower_bits:#x}')
        return bytes([(cls.message_type << 4) | lower_bits])
93
94
95# HIDP messages
@dataclass
class GetReportMessage(Message):
    """GET_REPORT request.

    Wire layout: header, one-byte report ID, then (only when ``buffer_size``
    is non-zero) a little-endian 16-bit buffer size. When the size is present,
    bit 3 (0x08) of the header's low nibble is set alongside the report type.
    """

    report_type: int
    report_id: int
    buffer_size: int
    message_type = Message.MessageType.GET_REPORT

    def __bytes__(self) -> bytes:
        id_byte = bytes([self.report_id])
        if self.buffer_size == 0:
            return self.header(self.report_type) + id_byte
        flagged_header = self.header(0x08 | self.report_type)
        return flagged_header + id_byte + struct.pack("<H", self.buffer_size)
114
115
@dataclass
class SetReportMessage(Message):
    """SET_REPORT request: the report type in the header's low bits, then
    ``data`` appended unchanged."""

    report_type: int
    data: bytes
    message_type = Message.MessageType.SET_REPORT

    def __bytes__(self) -> bytes:
        return b''.join((self.header(self.report_type), self.data))
124
125
@dataclass
class SendControlData(Message):
    """DATA message sent on the control channel (for example the device's
    reply to GET_REPORT or GET_PROTOCOL): report type in the header's low
    bits, followed by ``data``."""

    report_type: int
    data: bytes
    message_type = Message.MessageType.DATA

    def __bytes__(self) -> bytes:
        return b''.join((self.header(self.report_type), self.data))
134
135
@dataclass
class GetProtocolMessage(Message):
    """GET_PROTOCOL request: a bare header with no parameter bits or payload."""

    message_type = Message.MessageType.GET_PROTOCOL

    def __bytes__(self) -> bytes:
        return self.header(lower_bits=0x00)
142
143
@dataclass
class SetProtocolMessage(Message):
    """SET_PROTOCOL request; the protocol mode is carried in the header's low bits."""

    protocol_mode: int
    message_type = Message.MessageType.SET_PROTOCOL

    def __bytes__(self) -> bytes:
        return self.header(lower_bits=self.protocol_mode)
151
152
@dataclass
class Suspend(Message):
    """CONTROL message carrying the SUSPEND operation."""

    message_type = Message.MessageType.CONTROL

    def __bytes__(self) -> bytes:
        operation = Message.ControlCommand.SUSPEND
        return self.header(lower_bits=operation)
159
160
@dataclass
class ExitSuspend(Message):
    """CONTROL message carrying the EXIT_SUSPEND operation."""

    message_type = Message.MessageType.CONTROL

    def __bytes__(self) -> bytes:
        operation = Message.ControlCommand.EXIT_SUSPEND
        return self.header(lower_bits=operation)
167
168
@dataclass
class VirtualCableUnplug(Message):
    """CONTROL message carrying the VIRTUAL_CABLE_UNPLUG operation."""

    message_type = Message.MessageType.CONTROL

    def __bytes__(self) -> bytes:
        operation = Message.ControlCommand.VIRTUAL_CABLE_UNPLUG
        return self.header(lower_bits=operation)
175
176
# DATA message used on the interrupt channel: a device sends input reports and
# a host sends output reports (the caller picks the report type accordingly).
@dataclass
class SendData(Message):
    data: bytes
    report_type: int
    message_type = Message.MessageType.DATA

    def __bytes__(self) -> bytes:
        return b''.join((self.header(self.report_type), self.data))
186
187
@dataclass
class SendHandshakeMessage(Message):
    """HANDSHAKE message; ``result_code`` fills the header's low nibble."""

    result_code: int
    message_type = Message.MessageType.HANDSHAKE

    def __bytes__(self) -> bytes:
        return self.header(lower_bits=self.result_code)
195
196
197# -----------------------------------------------------------------------------
class HID(ABC, EventEmitter):
    """Shared HIDP (HID over L2CAP) plumbing for the host and device roles.

    Registers L2CAP servers for the HID control and interrupt PSMs, tracks the
    current connection and the two L2CAP channels, and routes received PDUs.
    Control-channel PDUs go to the role-specific ``on_ctrl_pdu``.
    Interrupt-channel PDUs are re-emitted as ``'interrupt_data'`` events.
    """

    l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
    l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
    connection: Optional[device.Connection] = None

    class Role(enum.IntEnum):
        HOST = 0x00
        DEVICE = 0x01

    def __init__(self, device: device.Device, role: Role) -> None:
        super().__init__()
        self.remote_device_bd_address: Optional[Address] = None
        self.device = device
        self.role = role

        # Register ourselves with the L2CAP channel manager so the peer can
        # open either HID channel towards us.
        device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection)
        device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection)

        device.on('connection', self.on_device_connection)

    async def _connect_channel(
        self, psm: int, sink: Callable[[bytes], None]
    ) -> l2cap.ClassicChannel:
        """Open an outgoing L2CAP channel on *psm* and route its PDUs to *sink*.

        Raises:
            InvalidStateError: if there is no current connection to the peer.
            ProtocolError: if the L2CAP connection fails (logged, then re-raised).
        """
        if self.connection is None:
            raise InvalidStateError('no connection')
        try:
            channel = await self.device.l2cap_channel_manager.connect(
                self.connection, psm
            )
        except ProtocolError:
            logger.exception('L2CAP connection failed.')
            raise
        channel.sink = sink
        # Forget the channel when it closes, the same way on_l2cap_connection
        # arranges for channels opened by the peer.
        channel.on('close', lambda: self.on_l2cap_channel_close(channel))
        return channel

    async def connect_control_channel(self) -> None:
        """Open the HID control channel to the connected peer."""
        self.l2cap_ctrl_channel = await self._connect_channel(
            HID_CONTROL_PSM, self.on_ctrl_pdu
        )

    async def connect_interrupt_channel(self) -> None:
        """Open the HID interrupt channel to the connected peer."""
        self.l2cap_intr_channel = await self._connect_channel(
            HID_INTERRUPT_PSM, self.on_intr_pdu
        )

    async def disconnect_interrupt_channel(self) -> None:
        """Close the interrupt channel.

        Raises:
            InvalidStateError: if the interrupt channel is not open.
        """
        if self.l2cap_intr_channel is None:
            raise InvalidStateError('invalid state')
        channel = self.l2cap_intr_channel
        # Clear first so nothing is sent on a channel that is being torn down.
        self.l2cap_intr_channel = None
        await channel.disconnect()

    async def disconnect_control_channel(self) -> None:
        """Close the control channel.

        Raises:
            InvalidStateError: if the control channel is not open.
        """
        if self.l2cap_ctrl_channel is None:
            raise InvalidStateError('invalid state')
        channel = self.l2cap_ctrl_channel
        self.l2cap_ctrl_channel = None
        await channel.disconnect()

    def on_device_connection(self, connection: device.Connection) -> None:
        """Remember the latest connection and the peer's address."""
        self.connection = connection
        self.remote_device_bd_address = connection.peer_address
        connection.on('disconnection', self.on_device_disconnection)

    def on_device_disconnection(self, reason: int) -> None:
        self.connection = None

    def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
        """Track a channel opened by the peer on one of the HID PSMs."""
        logger.debug(f'+++ New L2CAP connection: {l2cap_channel}')
        l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
        l2cap_channel.on('close', lambda: self.on_l2cap_channel_close(l2cap_channel))

    def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
        # Only the two HID PSMs are registered, so anything that is not the
        # control PSM is the interrupt channel.
        if l2cap_channel.psm == HID_CONTROL_PSM:
            self.l2cap_ctrl_channel = l2cap_channel
            self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
        else:
            self.l2cap_intr_channel = l2cap_channel
            self.l2cap_intr_channel.sink = self.on_intr_pdu
        logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')

    def on_l2cap_channel_close(self, l2cap_channel: l2cap.ClassicChannel) -> None:
        # Only forget the channel if it is still the one being tracked, so a
        # late close of a stale channel cannot drop a newer, open one.
        if l2cap_channel.psm == HID_CONTROL_PSM:
            if self.l2cap_ctrl_channel is l2cap_channel:
                self.l2cap_ctrl_channel = None
        elif self.l2cap_intr_channel is l2cap_channel:
            self.l2cap_intr_channel = None
        logger.debug(f'$$$ L2CAP channel close: {l2cap_channel}')

    @abstractmethod
    def on_ctrl_pdu(self, pdu: bytes) -> None:
        """Handle a PDU received on the control channel (role-specific)."""

    def on_intr_pdu(self, pdu: bytes) -> None:
        """Re-emit a PDU received on the interrupt channel as 'interrupt_data'."""
        logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
        self.emit("interrupt_data", pdu)

    def send_pdu_on_ctrl(self, msg: bytes) -> None:
        """Send a raw PDU on the control channel.

        Raises:
            InvalidStateError: if the control channel is not open.
        """
        if self.l2cap_ctrl_channel is None:
            raise InvalidStateError('HID control channel is not open')
        self.l2cap_ctrl_channel.send_pdu(msg)

    def send_pdu_on_intr(self, msg: bytes) -> None:
        """Send a raw PDU on the interrupt channel.

        Raises:
            InvalidStateError: if the interrupt channel is not open.
        """
        if self.l2cap_intr_channel is None:
            raise InvalidStateError('HID interrupt channel is not open')
        self.l2cap_intr_channel.send_pdu(msg)

    def send_data(self, data: bytes) -> None:
        """Send *data* as a DATA report on the interrupt channel.

        The report type follows the role: output report for a host, input
        report for a device. When no interrupt channel is open, the data is
        dropped and a warning is logged.
        """
        if self.role == HID.Role.HOST:
            report_type = Message.ReportType.OUTPUT_REPORT
        else:
            report_type = Message.ReportType.INPUT_REPORT
        msg = SendData(data, report_type)
        hid_message = bytes(msg)
        if self.l2cap_intr_channel is None:
            logger.warning('>>> HID INTERRUPT SEND DATA dropped: no interrupt channel')
            return
        logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
        self.send_pdu_on_intr(hid_message)

    def virtual_cable_unplug(self) -> None:
        """Send a VIRTUAL_CABLE_UNPLUG control operation on the control channel."""
        msg = VirtualCableUnplug()
        hid_message = bytes(msg)
        logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)
318
319
320# -----------------------------------------------------------------------------
321
322
class Device(HID):
    """HID device role.

    Requests from the host on the control channel (GET/SET REPORT and
    GET/SET PROTOCOL) are forwarded to callbacks registered through the
    ``register_*_cb`` methods. Each callback's ``GetSetStatus`` is turned into
    a DATA or HANDSHAKE reply. Requests with no registered callback get
    ERR_UNSUPPORTED_REQUEST. Control DATA and the suspend, exit-suspend and
    virtual-cable-unplug operations are re-emitted as events.
    """

    class GetSetReturn(enum.IntEnum):
        FAILURE = 0x00
        REPORT_ID_NOT_FOUND = 0x01
        ERR_UNSUPPORTED_REQUEST = 0x02
        ERR_UNKNOWN = 0x03
        ERR_INVALID_PARAMETER = 0x04
        SUCCESS = 0xFF

    # Result returned by the user callbacks: ``status`` is compared against
    # GetSetReturn, and ``data`` is the payload sent back on success.
    @dataclass
    class GetSetStatus:
        data: bytes = b''
        status: int = 0

    get_report_cb: Optional[Callable[[int, int, int], GetSetStatus]] = None
    set_report_cb: Optional[Callable[[int, int, int, bytes], GetSetStatus]] = None
    get_protocol_cb: Optional[Callable[[], GetSetStatus]] = None
    set_protocol_cb: Optional[Callable[[int], GetSetStatus]] = None

    def __init__(self, device: device.Device) -> None:
        super().__init__(device, HID.Role.DEVICE)

    @override
    def on_ctrl_pdu(self, pdu: bytes) -> None:
        """Dispatch a control-channel PDU from the host by message type."""
        logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
        if not pdu:
            # No header to decode; pdu[0] would raise IndexError in the sink.
            logger.warning('<<< HID CONTROL PDU empty, ignoring')
            return
        param = pdu[0] & 0x0F
        message_type = pdu[0] >> 4

        if message_type == Message.MessageType.GET_REPORT:
            logger.debug('<<< HID GET REPORT')
            self.handle_get_report(pdu)
        elif message_type == Message.MessageType.SET_REPORT:
            logger.debug('<<< HID SET REPORT')
            self.handle_set_report(pdu)
        elif message_type == Message.MessageType.GET_PROTOCOL:
            logger.debug('<<< HID GET PROTOCOL')
            self.handle_get_protocol(pdu)
        elif message_type == Message.MessageType.SET_PROTOCOL:
            logger.debug('<<< HID SET PROTOCOL')
            self.handle_set_protocol(pdu)
        elif message_type == Message.MessageType.DATA:
            logger.debug('<<< HID CONTROL DATA')
            self.emit('control_data', pdu)
        elif message_type == Message.MessageType.CONTROL:
            if param == Message.ControlCommand.SUSPEND:
                logger.debug('<<< HID SUSPEND')
                self.emit('suspend')
            elif param == Message.ControlCommand.EXIT_SUSPEND:
                logger.debug('<<< HID EXIT SUSPEND')
                self.emit('exit_suspend')
            elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
                logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
                self.emit('virtual_cable_unplug')
            else:
                logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
        else:
            logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)

    def send_handshake_message(self, result_code: int) -> None:
        """Send a HANDSHAKE carrying *result_code* on the control channel."""
        msg = SendHandshakeMessage(result_code)
        hid_message = bytes(msg)
        logger.debug(f'>>> HID HANDSHAKE MESSAGE, PDU: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def send_control_data(self, report_type: int, data: bytes) -> None:
        """Send a DATA message with *report_type* and *data* on the control channel."""
        msg = SendControlData(report_type=report_type, data=data)
        hid_message = bytes(msg)
        logger.debug(f'>>> HID CONTROL DATA: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def handle_get_report(self, pdu: bytes) -> None:
        """Answer a GET_REPORT request using ``get_report_cb``.

        Parsed layout: header (report type in bits 0-1, size flag in bit 3),
        report ID, then a little-endian 16-bit buffer size if the size flag is
        set. The callback is invoked as ``cb(report_id, report_type,
        buffer_size)``. On success the reply is a DATA message holding the
        report ID followed by the callback's data. Otherwise the reply is a
        HANDSHAKE error; a reply is always sent.
        """
        if self.get_report_cb is None:
            logger.debug("GetReport callback not registered !!")
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
            return
        report_type = pdu[0] & 0x03
        buffer_flag = (pdu[0] & 0x08) >> 3
        logger.debug(f"buffer_flag: {buffer_flag}")
        # Reject truncated requests instead of raising IndexError in the sink.
        if len(pdu) < (4 if buffer_flag else 2):
            logger.warning(f'<<< HID GET REPORT too short: {pdu.hex()}')
            self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
            return
        report_id = pdu[1]
        buffer_size = struct.unpack_from('<H', pdu, 2)[0] if buffer_flag else 0

        ret = self.get_report_cb(report_id, report_type, buffer_size)
        if ret.status == self.GetSetReturn.SUCCESS:
            data = bytearray([report_id])
            data.extend(ret.data)
            # The DATA header adds one byte, so the payload must stay strictly
            # below the channel's peer_mtu for the whole PDU to fit.
            if len(data) < self.l2cap_ctrl_channel.peer_mtu:  # type: ignore[union-attr]
                self.send_control_data(report_type=report_type, data=data)
            else:
                self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
        elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
            self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
        elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
            self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
        elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
        else:
            # FAILURE, ERR_UNKNOWN or an unrecognised status. Always reply so
            # the host is never left waiting for an answer.
            self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)

    def register_get_report_cb(
        self, cb: Callable[[int, int, int], Device.GetSetStatus]
    ) -> None:
        """Register ``cb(report_id, report_type, buffer_size) -> GetSetStatus``."""
        self.get_report_cb = cb
        logger.debug("GetReport callback registered successfully")

    def handle_set_report(self, pdu: bytes) -> None:
        """Answer a SET_REPORT request using ``set_report_cb``.

        The callback is invoked as ``cb(report_id, report_type, report_size,
        report_data)``, where ``report_size`` counts the report ID byte plus
        ``report_data``. The reply is always a HANDSHAKE.
        """
        if self.set_report_cb is None:
            logger.debug("SetReport callback not registered !!")
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
            return
        # Header plus report ID are required; avoid IndexError in the sink.
        if len(pdu) < 2:
            logger.warning(f'<<< HID SET REPORT too short: {pdu.hex()}')
            self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
            return
        report_type = pdu[0] & 0x03
        report_id = pdu[1]
        report_data = pdu[2:]
        report_size = len(report_data) + 1
        ret = self.set_report_cb(report_id, report_type, report_size, report_data)
        if ret.status == self.GetSetReturn.SUCCESS:
            self.send_handshake_message(Message.Handshake.SUCCESSFUL)
        elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
            self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
        elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
            self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
        else:
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)

    def register_set_report_cb(
        self, cb: Callable[[int, int, int, bytes], Device.GetSetStatus]
    ) -> None:
        """Register ``cb(report_id, report_type, report_size, report_data)``."""
        self.set_report_cb = cb
        logger.debug("SetReport callback registered successfully")

    def handle_get_protocol(self, pdu: bytes) -> None:
        """Answer GET_PROTOCOL with a DATA message holding the callback's data."""
        if self.get_protocol_cb is None:
            logger.debug("GetProtocol callback not registered !!")
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
            return
        ret = self.get_protocol_cb()
        if ret.status == self.GetSetReturn.SUCCESS:
            self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
        else:
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)

    def register_get_protocol_cb(self, cb: Callable[[], Device.GetSetStatus]) -> None:
        """Register ``cb() -> GetSetStatus`` for GET_PROTOCOL requests."""
        self.get_protocol_cb = cb
        logger.debug("GetProtocol callback registered successfully")

    def handle_set_protocol(self, pdu: bytes) -> None:
        """Answer SET_PROTOCOL; the mode passed to the callback is header bit 0."""
        if self.set_protocol_cb is None:
            logger.debug("SetProtocol callback not registered !!")
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
            return
        ret = self.set_protocol_cb(pdu[0] & 0x01)
        if ret.status == self.GetSetReturn.SUCCESS:
            self.send_handshake_message(Message.Handshake.SUCCESSFUL)
        else:
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)

    def register_set_protocol_cb(
        self, cb: Callable[[int], Device.GetSetStatus]
    ) -> None:
        """Register ``cb(protocol_mode) -> GetSetStatus`` for SET_PROTOCOL."""
        self.set_protocol_cb = cb
        logger.debug("SetProtocol callback registered successfully")
488
489
490# -----------------------------------------------------------------------------
class Host(HID):
    """HID host role: sends control-channel requests to a HID device.

    Request methods only send the PDU and return immediately. Replies from
    the device are surfaced by ``on_ctrl_pdu`` as ``'handshake'`` and
    ``'control_data'`` events.
    """

    def __init__(self, device: device.Device) -> None:
        super().__init__(device, HID.Role.HOST)

    def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
        """Send GET_REPORT; a non-zero *buffer_size* is included in the request."""
        msg = GetReportMessage(
            report_type=report_type, report_id=report_id, buffer_size=buffer_size
        )
        hid_message = bytes(msg)
        logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def set_report(self, report_type: int, data: bytes) -> None:
        """Send SET_REPORT with *data* appended after the header."""
        msg = SetReportMessage(report_type=report_type, data=data)
        hid_message = bytes(msg)
        logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def get_protocol(self) -> None:
        """Send GET_PROTOCOL."""
        msg = GetProtocolMessage()
        hid_message = bytes(msg)
        logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def set_protocol(self, protocol_mode: int) -> None:
        """Send SET_PROTOCOL with *protocol_mode* (see Message.ProtocolMode)."""
        msg = SetProtocolMessage(protocol_mode=protocol_mode)
        hid_message = bytes(msg)
        logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def suspend(self) -> None:
        """Send the SUSPEND control operation."""
        msg = Suspend()
        hid_message = bytes(msg)
        logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def exit_suspend(self) -> None:
        """Send the EXIT_SUSPEND control operation."""
        msg = ExitSuspend()
        hid_message = bytes(msg)
        logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    @override
    def on_ctrl_pdu(self, pdu: bytes) -> None:
        """Decode a control-channel PDU from the device and emit the matching event."""
        logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
        if not pdu:
            # No header to decode; pdu[0] would raise IndexError in the sink.
            logger.warning('<<< HID CONTROL PDU empty, ignoring')
            return
        param = pdu[0] & 0x0F
        message_type = pdu[0] >> 4
        if message_type == Message.MessageType.HANDSHAKE:
            try:
                handshake = Message.Handshake(param)
            except ValueError:
                # Result code not in Message.Handshake. Converting it would
                # raise inside the L2CAP sink, so log the code and drop the PDU.
                logger.warning(
                    f'<<< HID HANDSHAKE with unknown result code: {param:#x}'
                )
                return
            logger.debug(f'<<< HID HANDSHAKE: {handshake.name}')
            self.emit('handshake', handshake)
        elif message_type == Message.MessageType.DATA:
            logger.debug('<<< HID CONTROL DATA')
            self.emit('control_data', pdu)
        elif message_type == Message.MessageType.CONTROL:
            if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
                logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
                self.emit('virtual_cable_unplug')
            else:
                logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
        else:
            logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
552