# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from dataclasses import dataclass
import logging
import enum
import struct

from abc import ABC, abstractmethod
from pyee import EventEmitter
from typing import Optional, Callable
from typing_extensions import override

from bumble import l2cap, device
from bumble.core import InvalidStateError, ProtocolError
from bumble.hci import Address


# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: on
HID_CONTROL_PSM = 0x0011
HID_INTERRUPT_PSM = 0x0013


class Message:
    """Base class for HID protocol messages.

    Subclasses set ``message_type`` and implement ``__bytes__``. The first
    byte of every serialized message is built by :meth:`header`: the message
    type in the upper 4 bits and a message-specific parameter in the lower
    4 bits.
    """

    message_type: MessageType

    # Report types
    class ReportType(enum.IntEnum):
        OTHER_REPORT = 0x00
        INPUT_REPORT = 0x01
        OUTPUT_REPORT = 0x02
        FEATURE_REPORT = 0x03

    # Handshake parameters
    class Handshake(enum.IntEnum):
        SUCCESSFUL = 0x00
        NOT_READY = 0x01
        ERR_INVALID_REPORT_ID = 0x02
        ERR_UNSUPPORTED_REQUEST = 0x03
        ERR_INVALID_PARAMETER = 0x04
        ERR_UNKNOWN = 0x0E
        ERR_FATAL = 0x0F

    # Message Type
    class MessageType(enum.IntEnum):
        HANDSHAKE = 0x00
        CONTROL = 0x01
        GET_REPORT = 0x04
        SET_REPORT = 0x05
        GET_PROTOCOL = 0x06
        SET_PROTOCOL = 0x07
        DATA = 0x0A

    # Protocol modes
    class ProtocolMode(enum.IntEnum):
        BOOT_PROTOCOL = 0x00
        REPORT_PROTOCOL = 0x01

    # Control Operations
    class ControlCommand(enum.IntEnum):
        SUSPEND = 0x03
        EXIT_SUSPEND = 0x04
        VIRTUAL_CABLE_UNPLUG = 0x05

    # Class Method to derive header
    @classmethod
    def header(cls, lower_bits: int = 0x00) -> bytes:
        """Return the 1-byte header: ``message_type`` in the high nibble,
        ``lower_bits`` OR-ed into the low bits."""
        return bytes([(cls.message_type << 4) | lower_bits])


# HIDP messages
@dataclass
class GetReportMessage(Message):
    """GET_REPORT request.

    When ``buffer_size`` is non-zero, bit 3 of the header is set and the size
    is appended as a little-endian 16-bit value after the report id.
    """

    report_type: int
    report_id: int
    buffer_size: int
    message_type = Message.MessageType.GET_REPORT

    def __bytes__(self) -> bytes:
        payload = bytes([self.report_id])
        if self.buffer_size == 0:
            return self.header(self.report_type) + payload
        # 0x08 is the "size present" flag in the header's low nibble.
        return (
            self.header(0x08 | self.report_type)
            + payload
            + struct.pack("<H", self.buffer_size)
        )


@dataclass
class SetReportMessage(Message):
    """SET_REPORT request: header carrying the report type, then ``data``."""

    report_type: int
    data: bytes
    message_type = Message.MessageType.SET_REPORT

    def __bytes__(self) -> bytes:
        return self.header(self.report_type) + self.data


@dataclass
class SendControlData(Message):
    """DATA message built for the control channel (see
    ``Device.send_control_data``)."""

    report_type: int
    data: bytes
    message_type = Message.MessageType.DATA

    def __bytes__(self) -> bytes:
        return self.header(self.report_type) + self.data


@dataclass
class GetProtocolMessage(Message):
    """GET_PROTOCOL request: a bare header with no parameter."""

    message_type = Message.MessageType.GET_PROTOCOL

    def __bytes__(self) -> bytes:
        return self.header()


@dataclass
class SetProtocolMessage(Message):
    """SET_PROTOCOL request: the protocol mode travels in the header."""

    protocol_mode: int
    message_type = Message.MessageType.SET_PROTOCOL

    def __bytes__(self) -> bytes:
        return self.header(self.protocol_mode)


@dataclass
class Suspend(Message):
    """CONTROL message carrying the SUSPEND operation."""

    message_type = Message.MessageType.CONTROL

    def __bytes__(self) -> bytes:
        return self.header(Message.ControlCommand.SUSPEND)


@dataclass
class ExitSuspend(Message):
    """CONTROL message carrying the EXIT_SUSPEND operation."""

    message_type = Message.MessageType.CONTROL

    def __bytes__(self) -> bytes:
        return self.header(Message.ControlCommand.EXIT_SUSPEND)


@dataclass
class VirtualCableUnplug(Message):
    """CONTROL message carrying the VIRTUAL_CABLE_UNPLUG operation."""

    message_type = Message.MessageType.CONTROL

    def __bytes__(self) -> bytes:
        return self.header(Message.ControlCommand.VIRTUAL_CABLE_UNPLUG)


# Device sends input report, host sends output report.
@dataclass
class SendData(Message):
    """DATA message built by ``HID.send_data`` for the interrupt channel."""

    data: bytes
    report_type: int
    message_type = Message.MessageType.DATA

    def __bytes__(self) -> bytes:
        return self.header(self.report_type) + self.data


@dataclass
class SendHandshakeMessage(Message):
    """HANDSHAKE response: the result code travels in the header."""

    result_code: int
    message_type = Message.MessageType.HANDSHAKE

    def __bytes__(self) -> bytes:
        return self.header(self.result_code)


# -----------------------------------------------------------------------------
class HID(ABC, EventEmitter):
    """Common base for the HID host and device roles.

    Registers L2CAP servers for the control and interrupt PSMs, tracks the
    current connection and the two channels, and emits ``interrupt_data``
    for every PDU received on the interrupt channel. Subclasses implement
    :meth:`on_ctrl_pdu` to handle control-channel PDUs.
    """

    l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
    l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
    connection: Optional[device.Connection] = None

    class Role(enum.IntEnum):
        HOST = 0x00
        DEVICE = 0x01

    def __init__(self, device: device.Device, role: Role) -> None:
        super().__init__()
        self.remote_device_bd_address: Optional[Address] = None
        self.device = device
        self.role = role

        # Register ourselves with the L2CAP channel manager
        device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection)
        device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection)

        device.on('connection', self.on_device_connection)

    async def connect_control_channel(self) -> None:
        """Open the control channel on the current connection.

        Raises:
            ProtocolError: re-raised after logging if the L2CAP connect fails.
        """
        # Create a new L2CAP connection - control channel
        try:
            channel = await self.device.l2cap_channel_manager.connect(
                self.connection, HID_CONTROL_PSM
            )
            channel.sink = self.on_ctrl_pdu
            self.l2cap_ctrl_channel = channel
        except ProtocolError:
            # Use the module logger (not the root logger) like the rest of
            # this file.
            logger.exception('L2CAP connection failed.')
            raise

    async def connect_interrupt_channel(self) -> None:
        """Open the interrupt channel on the current connection.

        Raises:
            ProtocolError: re-raised after logging if the L2CAP connect fails.
        """
        # Create a new L2CAP connection - interrupt channel
        try:
            channel = await self.device.l2cap_channel_manager.connect(
                self.connection, HID_INTERRUPT_PSM
            )
            channel.sink = self.on_intr_pdu
            self.l2cap_intr_channel = channel
        except ProtocolError:
            logger.exception('L2CAP connection failed.')
            raise

    async def disconnect_interrupt_channel(self) -> None:
        """Disconnect the interrupt channel.

        Raises:
            InvalidStateError: if no interrupt channel is currently set.
        """
        if self.l2cap_intr_channel is None:
            raise InvalidStateError('invalid state')
        # Clear the reference before awaiting the disconnect.
        channel = self.l2cap_intr_channel
        self.l2cap_intr_channel = None
        await channel.disconnect()

    async def disconnect_control_channel(self) -> None:
        """Disconnect the control channel.

        Raises:
            InvalidStateError: if no control channel is currently set.
        """
        if self.l2cap_ctrl_channel is None:
            raise InvalidStateError('invalid state')
        # Clear the reference before awaiting the disconnect.
        channel = self.l2cap_ctrl_channel
        self.l2cap_ctrl_channel = None
        await channel.disconnect()

    def on_device_connection(self, connection: device.Connection) -> None:
        """Record the new connection and its peer address, and subscribe to
        its ``disconnection`` event."""
        self.connection = connection
        self.remote_device_bd_address = connection.peer_address
        connection.on('disconnection', self.on_device_disconnection)

    def on_device_disconnection(self, reason: int) -> None:
        """Forget the current connection; ``reason`` is not used."""
        self.connection = None

    def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
        """Hook the open/close events of an incoming L2CAP channel."""
        logger.debug(f'+++ New L2CAP connection: {l2cap_channel}')
        l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
        l2cap_channel.on('close', lambda: self.on_l2cap_channel_close(l2cap_channel))

    def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
        """Store the opened channel and attach the matching PDU sink.

        Any PSM other than ``HID_CONTROL_PSM`` is treated as the interrupt
        channel.
        """
        if l2cap_channel.psm == HID_CONTROL_PSM:
            self.l2cap_ctrl_channel = l2cap_channel
            self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
        else:
            self.l2cap_intr_channel = l2cap_channel
            self.l2cap_intr_channel.sink = self.on_intr_pdu
        logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')

    def on_l2cap_channel_close(self, l2cap_channel: l2cap.ClassicChannel) -> None:
        """Drop the reference to the channel that was closed."""
        if l2cap_channel.psm == HID_CONTROL_PSM:
            self.l2cap_ctrl_channel = None
        else:
            self.l2cap_intr_channel = None
        logger.debug(f'$$$ L2CAP channel close: {l2cap_channel}')

    @abstractmethod
    def on_ctrl_pdu(self, pdu: bytes) -> None:
        """Handle a PDU received on the control channel."""

    def on_intr_pdu(self, pdu: bytes) -> None:
        """Emit ``interrupt_data`` with the raw PDU received on the interrupt
        channel."""
        logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
        self.emit("interrupt_data", pdu)

    def send_pdu_on_ctrl(self, msg: bytes) -> None:
        """Send ``msg`` on the control channel, which must be open."""
        assert self.l2cap_ctrl_channel
        self.l2cap_ctrl_channel.send_pdu(msg)

    def send_pdu_on_intr(self, msg: bytes) -> None:
        """Send ``msg`` on the interrupt channel, which must be open."""
        assert self.l2cap_intr_channel
        self.l2cap_intr_channel.send_pdu(msg)

    def send_data(self, data: bytes) -> None:
        """Send ``data`` as a DATA message on the interrupt channel.

        The report type is OUTPUT_REPORT for the host role and INPUT_REPORT
        otherwise. Nothing is sent if the interrupt channel is not set.
        """
        if self.role == HID.Role.HOST:
            report_type = Message.ReportType.OUTPUT_REPORT
        else:
            report_type = Message.ReportType.INPUT_REPORT
        hid_message = bytes(SendData(data, report_type))
        if self.l2cap_intr_channel is not None:
            logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
            self.send_pdu_on_intr(hid_message)

    def virtual_cable_unplug(self) -> None:
        """Send a VIRTUAL_CABLE_UNPLUG control message on the control channel."""
        hid_message = bytes(VirtualCableUnplug())
        logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)


# -----------------------------------------------------------------------------


class Device(HID):
    """HID device role.

    Incoming GET/SET REPORT and GET/SET PROTOCOL requests are delegated to
    callbacks registered through the ``register_*_cb`` methods; each callback
    returns a :class:`Device.GetSetStatus` whose ``status`` is mapped to a
    handshake (or a control DATA response on success of a GET request).
    Emits ``control_data``, ``suspend``, ``exit_suspend`` and
    ``virtual_cable_unplug``.
    """

    class GetSetReturn(enum.IntEnum):
        FAILURE = 0x00
        REPORT_ID_NOT_FOUND = 0x01
        ERR_UNSUPPORTED_REQUEST = 0x02
        ERR_UNKNOWN = 0x03
        ERR_INVALID_PARAMETER = 0x04
        SUCCESS = 0xFF

    @dataclass
    class GetSetStatus:
        data: bytes = b''
        status: int = 0

    get_report_cb: Optional[Callable[[int, int, int], GetSetStatus]] = None
    set_report_cb: Optional[Callable[[int, int, int, bytes], GetSetStatus]] = None
    get_protocol_cb: Optional[Callable[[], GetSetStatus]] = None
    set_protocol_cb: Optional[Callable[[int], GetSetStatus]] = None

    def __init__(self, device: device.Device) -> None:
        super().__init__(device, HID.Role.DEVICE)

    @override
    def on_ctrl_pdu(self, pdu: bytes) -> None:
        """Dispatch a control-channel PDU by the message type in its header.

        Unsupported message types are answered with an
        ERR_UNSUPPORTED_REQUEST handshake. An empty PDU is logged and ignored.
        """
        logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
        if not pdu:
            # No header byte to decode; indexing pdu[0] would raise.
            logger.warning('<<< HID CONTROL PDU is empty, ignoring')
            return
        param = pdu[0] & 0x0F
        message_type = pdu[0] >> 4

        if message_type == Message.MessageType.GET_REPORT:
            logger.debug('<<< HID GET REPORT')
            self.handle_get_report(pdu)
        elif message_type == Message.MessageType.SET_REPORT:
            logger.debug('<<< HID SET REPORT')
            self.handle_set_report(pdu)
        elif message_type == Message.MessageType.GET_PROTOCOL:
            logger.debug('<<< HID GET PROTOCOL')
            self.handle_get_protocol(pdu)
        elif message_type == Message.MessageType.SET_PROTOCOL:
            logger.debug('<<< HID SET PROTOCOL')
            self.handle_set_protocol(pdu)
        elif message_type == Message.MessageType.DATA:
            logger.debug('<<< HID CONTROL DATA')
            self.emit('control_data', pdu)
        elif message_type == Message.MessageType.CONTROL:
            if param == Message.ControlCommand.SUSPEND:
                logger.debug('<<< HID SUSPEND')
                self.emit('suspend')
            elif param == Message.ControlCommand.EXIT_SUSPEND:
                logger.debug('<<< HID EXIT SUSPEND')
                self.emit('exit_suspend')
            elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
                logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
                self.emit('virtual_cable_unplug')
            else:
                logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
        else:
            logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)

    def send_handshake_message(self, result_code: int) -> None:
        """Send a HANDSHAKE with ``result_code`` on the control channel."""
        hid_message = bytes(SendHandshakeMessage(result_code))
        logger.debug(f'>>> HID HANDSHAKE MESSAGE, PDU: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def send_control_data(self, report_type: int, data: bytes) -> None:
        """Send a DATA message with ``report_type`` and ``data`` on the
        control channel."""
        hid_message = bytes(SendControlData(report_type=report_type, data=data))
        logger.debug(f'>>> HID CONTROL DATA: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def handle_get_report(self, pdu: bytes) -> None:
        """Answer a GET_REPORT request using the registered callback.

        The callback receives ``(report_id, report_type, buffer_size)``;
        ``buffer_size`` is 0 when the request's size flag (bit 3 of the
        header) is clear. On SUCCESS the report id followed by the callback's
        data is sent as control DATA if it fits the peer MTU; every other
        outcome is answered with a handshake so the peer always gets a reply.
        """
        if self.get_report_cb is None:
            logger.debug("GetReport callback not registered !!")
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
            return

        # The header must be followed by at least the report id.
        if len(pdu) < 2:
            logger.debug("GetReport PDU too short")
            self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
            return

        report_type = pdu[0] & 0x03
        buffer_flag = (pdu[0] & 0x08) >> 3
        report_id = pdu[1]
        logger.debug(f"buffer_flag: {buffer_flag}")
        if buffer_flag == 1:
            # The size flag promises a little-endian 16-bit size at offset 2.
            if len(pdu) < 4:
                logger.debug("GetReport PDU missing buffer size")
                self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
                return
            buffer_size = struct.unpack_from("<H", pdu, 2)[0]
        else:
            buffer_size = 0

        ret = self.get_report_cb(report_id, report_type, buffer_size)
        status = ret.status
        if status == self.GetSetReturn.SUCCESS:
            data = bytearray([report_id])
            data.extend(ret.data)
            # Strictly less than the MTU: the 1-byte header is added on send.
            if len(data) < self.l2cap_ctrl_channel.peer_mtu:  # type: ignore[union-attr]
                self.send_control_data(report_type=report_type, data=data)
            else:
                self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
        elif status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
            self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
        elif status == self.GetSetReturn.ERR_INVALID_PARAMETER:
            self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
        elif status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
        else:
            # FAILURE, ERR_UNKNOWN and any unrecognised status: never leave
            # the request unanswered.
            self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)

    def register_get_report_cb(
        self, cb: Callable[[int, int, int], Device.GetSetStatus]
    ) -> None:
        """Register the callback used by :meth:`handle_get_report`."""
        self.get_report_cb = cb
        logger.debug("GetReport callback registered successfully")

    def handle_set_report(self, pdu: bytes) -> None:
        """Answer a SET_REPORT request using the registered callback.

        The callback receives ``(report_id, report_type, report_size,
        report_data)`` where ``report_size`` counts the report id byte plus
        the data. The callback's status is mapped to a handshake.
        """
        if self.set_report_cb is None:
            logger.debug("SetReport callback not registered !!")
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
            return

        # The header must be followed by at least the report id.
        if len(pdu) < 2:
            logger.debug("SetReport PDU too short")
            self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
            return

        report_type = pdu[0] & 0x03
        report_id = pdu[1]
        report_data = pdu[2:]
        report_size = len(report_data) + 1
        ret = self.set_report_cb(report_id, report_type, report_size, report_data)
        if ret.status == self.GetSetReturn.SUCCESS:
            self.send_handshake_message(Message.Handshake.SUCCESSFUL)
        elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
            self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
        elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
            self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
        else:
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)

    def register_set_report_cb(
        self, cb: Callable[[int, int, int, bytes], Device.GetSetStatus]
    ) -> None:
        """Register the callback used by :meth:`handle_set_report`."""
        self.set_report_cb = cb
        logger.debug("SetReport callback registered successfully")

    def handle_get_protocol(self, pdu: bytes) -> None:
        """Answer a GET_PROTOCOL request using the registered callback.

        On SUCCESS the callback's data is sent as control DATA with report
        type OTHER_REPORT; otherwise ERR_UNSUPPORTED_REQUEST is sent.
        """
        if self.get_protocol_cb is None:
            logger.debug("GetProtocol callback not registered !!")
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
            return
        ret = self.get_protocol_cb()
        if ret.status == self.GetSetReturn.SUCCESS:
            self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
        else:
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)

    def register_get_protocol_cb(self, cb: Callable[[], Device.GetSetStatus]) -> None:
        """Register the callback used by :meth:`handle_get_protocol`."""
        self.get_protocol_cb = cb
        logger.debug("GetProtocol callback registered successfully")

    def handle_set_protocol(self, pdu: bytes) -> None:
        """Answer a SET_PROTOCOL request using the registered callback.

        The callback receives the lowest bit of the header as the requested
        protocol mode; its status is mapped to SUCCESSFUL or
        ERR_UNSUPPORTED_REQUEST.
        """
        if self.set_protocol_cb is None:
            logger.debug("SetProtocol callback not registered !!")
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
            return
        ret = self.set_protocol_cb(pdu[0] & 0x01)
        if ret.status == self.GetSetReturn.SUCCESS:
            self.send_handshake_message(Message.Handshake.SUCCESSFUL)
        else:
            self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)

    def register_set_protocol_cb(
        self, cb: Callable[[int], Device.GetSetStatus]
    ) -> None:
        """Register the callback used by :meth:`handle_set_protocol`."""
        self.set_protocol_cb = cb
        logger.debug("SetProtocol callback registered successfully")


# -----------------------------------------------------------------------------
class Host(HID):
    """HID host role.

    Provides methods that send requests on the control channel and emits
    ``handshake``, ``control_data`` and ``virtual_cable_unplug`` for the
    control-channel PDUs it receives.
    """

    def __init__(self, device: device.Device) -> None:
        super().__init__(device, HID.Role.HOST)

    def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
        """Send a GET_REPORT request on the control channel."""
        msg = GetReportMessage(
            report_type=report_type, report_id=report_id, buffer_size=buffer_size
        )
        hid_message = bytes(msg)
        logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def set_report(self, report_type: int, data: bytes) -> None:
        """Send a SET_REPORT request on the control channel."""
        hid_message = bytes(SetReportMessage(report_type=report_type, data=data))
        logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def get_protocol(self) -> None:
        """Send a GET_PROTOCOL request on the control channel."""
        hid_message = bytes(GetProtocolMessage())
        logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def set_protocol(self, protocol_mode: int) -> None:
        """Send a SET_PROTOCOL request on the control channel."""
        hid_message = bytes(SetProtocolMessage(protocol_mode=protocol_mode))
        logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def suspend(self) -> None:
        """Send a SUSPEND control message on the control channel."""
        hid_message = bytes(Suspend())
        logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    def exit_suspend(self) -> None:
        """Send an EXIT_SUSPEND control message on the control channel."""
        hid_message = bytes(ExitSuspend())
        logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
        self.send_pdu_on_ctrl(hid_message)

    @override
    def on_ctrl_pdu(self, pdu: bytes) -> None:
        """Dispatch a control-channel PDU by the message type in its header.

        Emits ``handshake`` with a ``Message.Handshake`` value, ``control_data``
        with the raw PDU, or ``virtual_cable_unplug``. Empty PDUs and
        handshakes whose result code is not a ``Message.Handshake`` member
        are logged and ignored.
        """
        logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
        if not pdu:
            # No header byte to decode; indexing pdu[0] would raise.
            logger.warning('<<< HID CONTROL PDU is empty, ignoring')
            return
        param = pdu[0] & 0x0F
        message_type = pdu[0] >> 4
        if message_type == Message.MessageType.HANDSHAKE:
            try:
                handshake = Message.Handshake(param)
            except ValueError:
                # The result code comes from the peer and may not be a
                # defined enum member; do not let it raise out of the sink.
                logger.warning(f'<<< HID HANDSHAKE with unknown result code: {param:#x}')
                return
            logger.debug(f'<<< HID HANDSHAKE: {handshake.name}')
            self.emit('handshake', handshake)
        elif message_type == Message.MessageType.DATA:
            logger.debug('<<< HID CONTROL DATA')
            self.emit('control_data', pdu)
        elif message_type == Message.MessageType.CONTROL:
            if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
                logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
                self.emit('virtual_cable_unplug')
            else:
                logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
        else:
            logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')