# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import threading
import time

import usb.core
import usb.util

from typing import Optional, Set
from usb.core import Device as UsbDevice
from usb.core import USBError
from usb.util import CTRL_TYPE_CLASS, CTRL_RECIPIENT_OTHER
from usb.legacy import REQ_SET_FEATURE, REQ_CLEAR_FEATURE, CLASS_HUB

from .common import Transport, ParserSource, TransportInitError
from .. import hci
from ..colors import color


# -----------------------------------------------------------------------------
# Constant
# -----------------------------------------------------------------------------
# Hub port feature selector passed as wValue to SET_FEATURE / CLEAR_FEATURE
# to switch a hub port's power on / off.
USB_PORT_FEATURE_POWER = 8
# Seconds a hub port is left unpowered during a power cycle.
POWER_CYCLE_DELAY = 1
# Seconds to wait after restoring port power before looking the device up again.
RESET_DELAY = 3

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------
# Global
# -----------------------------------------------------------------------------
# Addresses of the devices claimed by transports opened with the
# <vendor>:<product> syntax, so that two such transports don't pick the same
# device.
# NOTE(review): USB addresses are only unique within a bus; two devices on
# different buses may share an address.
devices_in_use: Set[int] = set()


# -----------------------------------------------------------------------------
def _usb_find_function():
    '''
    Return the function used to enumerate USB devices: `libusb_package.find`
    when the optional `libusb_package` module can be imported, otherwise
    `usb.core.find`.
    '''
    try:
        import libusb_package
    except ImportError:
        logger.debug('libusb_package is not available')
        return usb.core.find
    return libusb_package.find


def _device_path(device: UsbDevice) -> str:
    '''
    Return the `<bus>-<port>[.<port>...]` path of a device, or just `<bus>` when
    the device reports no port numbers.
    '''
    port_numbers = device.port_numbers  # type: ignore
    if port_numbers:
        return f'{device.bus}-{".".join(map(str, port_numbers))}'  # type: ignore
    return str(device.bus)  # type: ignore


# -----------------------------------------------------------------------------
async def open_pyusb_transport(spec: str) -> Transport:
    '''
    Open a USB transport. [Implementation based on PyUSB]
    The parameter string has this syntax:
      [!]<index>, [!]<vendor>:<product> or [!]<bus>-<port>[.<port>...]
    With <index> as the 0-based index to select amongst all the devices that appear
    to be supporting Bluetooth HCI (0 being the first one), <vendor> and <product>
    as the vendor ID and product ID in hexadecimal, or <bus>-<port>[.<port>...] as
    the bus number followed by the chain of port numbers leading to the device.
    An optional '!' prefix requests a power cycle of the device (by switching the
    power of its port on its parent hub off and back on) before it is opened; this
    only works for devices connected to a compatible USB hub, and is skipped with a
    log message otherwise.

    Examples:
      0 --> the first BT USB dongle
      04b4:f901 --> the BT USB dongle with vendor=04b4 and product=f901
      1-2.3 --> the device on port 3 of the hub plugged in port 2 of bus 1
      !04b4:f901 --> same as 04b4:f901, power cycling the dongle first

    Raises:
      TransportInitError: no device matches the spec, or the device could not be
        found again after a power cycle.
      ValueError: every <vendor>:<product> match is already in use by another
        transport, or <index> / the hexadecimal IDs are malformed.
    '''

    # pylint: disable=invalid-name
    USB_RECIPIENT_DEVICE = 0x00
    USB_REQUEST_TYPE_CLASS = 0x01 << 5
    USB_ENDPOINT_EVENTS_IN = 0x81
    USB_ENDPOINT_ACL_IN = 0x82
    USB_ENDPOINT_SCO_IN = 0x83
    USB_ENDPOINT_ACL_OUT = 0x02
    #  USB_ENDPOINT_SCO_OUT                             = 0x03
    USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
    USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
    USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01

    READ_SIZE = 1024
    READ_TIMEOUT = 1000  # milliseconds (pyusb read timeout)

    class UsbPacketSink:
        '''
        Sends outgoing HCI packets to the device: ACL data is written to the ACL
        OUT endpoint, commands are sent as class control transfers.
        '''

        def __init__(self, device):
            self.device = device
            self.thread = threading.Thread(target=self.run)
            self.loop = asyncio.get_running_loop()
            self.stop_event = None

        def on_packet(self, packet):
            # TODO: don't block here, just queue for the write thread
            if len(packet) == 0:
                logger.warning('packet too short')
                return

            packet_type = packet[0]
            try:
                if packet_type == hci.HCI_ACL_DATA_PACKET:
                    self.device.write(USB_ENDPOINT_ACL_OUT, packet[1:])
                elif packet_type == hci.HCI_COMMAND_PACKET:
                    self.device.ctrl_transfer(
                        USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
                        0,
                        0,
                        0,
                        packet[1:],
                    )
                else:
                    logger.warning(
                        color(f'unsupported packet type {packet_type}', 'red')
                    )
            except usb.core.USBTimeoutError:
                logger.warning('USB Write Timeout')
            except usb.core.USBError as error:
                logger.warning(f'USB write error: {error}')
                # NOTE(review): this sleep blocks the caller of on_packet
                # (presumably the asyncio loop thread) for one second.
                time.sleep(1)  # Sleep one second to avoid busy looping

        def start(self):
            self.thread.start()

        async def stop(self):
            # Create stop events and wait for them to be signaled
            self.stop_event = asyncio.Event()
            await self.stop_event.wait()

        def run(self):
            # The thread does no writing yet (see the TODO in on_packet): it only
            # waits until stop() creates the event, then signals it on the loop.
            while self.stop_event is None:
                time.sleep(1)
            self.loop.call_soon_threadsafe(self.stop_event.set)

    class UsbPacketSource(asyncio.Protocol, ParserSource):
        '''
        Reads incoming HCI packets from the device: one thread per IN endpoint
        (events, ACL data, and optionally SCO) prefixes each read with its HCI
        packet type and hands it to the asyncio loop, where it is fed to the parser.
        '''

        def __init__(self, device, metadata, sco_enabled):
            super().__init__()
            self.device = device
            self.metadata = metadata
            self.loop = asyncio.get_running_loop()
            self.queue = asyncio.Queue()
            self.dequeue_task = None
            self.event_thread = threading.Thread(
                target=self.run, args=(USB_ENDPOINT_EVENTS_IN, hci.HCI_EVENT_PACKET)
            )
            self.event_thread.stop_event = None
            self.acl_thread = threading.Thread(
                target=self.run, args=(USB_ENDPOINT_ACL_IN, hci.HCI_ACL_DATA_PACKET)
            )
            self.acl_thread.stop_event = None

            # SCO support is optional
            self.sco_enabled = sco_enabled
            if sco_enabled:
                self.sco_thread = threading.Thread(
                    target=self.run,
                    args=(USB_ENDPOINT_SCO_IN, hci.HCI_SYNCHRONOUS_DATA_PACKET),
                )
                self.sco_thread.stop_event = None

        def data_received(self, data):
            self.parser.feed_data(data)

        def enqueue(self, packet):
            self.queue.put_nowait(packet)

        async def dequeue(self):
            while True:
                try:
                    data = await self.queue.get()
                except asyncio.CancelledError:
                    return
                self.data_received(data)

        def start(self):
            self.dequeue_task = self.loop.create_task(self.dequeue())
            self.event_thread.start()
            self.acl_thread.start()
            if self.sco_enabled:
                self.sco_thread.start()

        async def stop(self):
            # Stop the dequeuing task
            self.dequeue_task.cancel()

            # Create a stop event for every reader thread (including the SCO one,
            # otherwise it would never stop and its event would still be None),
            # then wait for each thread to signal it.
            self.event_thread.stop_event = asyncio.Event()
            self.acl_thread.stop_event = asyncio.Event()
            if self.sco_enabled:
                self.sco_thread.stop_event = asyncio.Event()
            await self.event_thread.stop_event.wait()
            await self.acl_thread.stop_event.wait()
            if self.sco_enabled:
                await self.sco_thread.stop_event.wait()

        def run(self, endpoint, packet_type):
            # Read until asked to stop
            current_thread = threading.current_thread()
            while current_thread.stop_event is None:
                try:
                    # Read, with a timeout of 1 second
                    data = self.device.read(endpoint, READ_SIZE, timeout=READ_TIMEOUT)
                    packet = bytes([packet_type]) + data.tobytes()
                    self.loop.call_soon_threadsafe(self.enqueue, packet)
                except usb.core.USBTimeoutError:
                    continue
                except usb.core.USBError:
                    # Don't log this: because pyusb doesn't really support multiple
                    # threads reading at the same time, we can get occasional
                    # USBError(errno=5) Input/Output errors reported, but they seem to
                    # be harmless.
                    # Until support for async or multi-thread support is added to pyusb,
                    # we'll just live with this as is...
                    # logger.warning(f'USB read error: {error}')
                    time.sleep(1)  # Sleep one second to avoid busy looping

            stop_event = current_thread.stop_event
            self.loop.call_soon_threadsafe(stop_event.set)

    class UsbTransport(Transport):
        def __init__(self, device, source, sink, in_use_address=None):
            super().__init__(source, sink)
            self.device = device
            # Address this transport registered in devices_in_use (None if it
            # did not register one); released on close.
            self.in_use_address = in_use_address

        async def close(self):
            await self.source.stop()
            await self.sink.stop()
            usb.util.release_interface(self.device, 0)
            if self.in_use_address is not None:
                devices_in_use.discard(self.in_use_address)

    usb_find = _usb_find_function()

    # Find the device according to the spec moniker
    power_cycle = spec.startswith('!')
    if power_cycle:
        spec = spec[1:]

    # Address added to devices_in_use by this call, so it can be released if
    # anything below fails, or by the transport when it is closed.
    in_use_address: Optional[int] = None
    if ':' in spec:
        vendor_id, product_id = spec.split(':')
        candidates = list(
            usb_find(
                find_all=True,
                idVendor=int(vendor_id, 16),
                idProduct=int(product_id, 16),
            )
        )
        device = next(
            (
                candidate
                for candidate in candidates
                if candidate.address not in devices_in_use
            ),
            None,
        )
        if device is None and candidates:
            raise ValueError('device already in use')
        if device is not None:
            devices_in_use.add(device.address)
            in_use_address = device.address
    elif '-' in spec:
        device = usb_find(custom_match=lambda candidate: _device_path(candidate) == spec)
    else:
        device_index = int(spec)
        devices = list(
            usb_find(
                find_all=True,
                bDeviceClass=USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
                bDeviceSubClass=USB_DEVICE_SUBCLASS_RF_CONTROLLER,
                bDeviceProtocol=USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
            )
        )
        device = devices[device_index] if len(devices) > device_index else None

    if device is None:
        raise TransportInitError('device not found')
    logger.debug(f'USB Device: {device}')

    try:
        # Power Cycle the device
        if power_cycle:
            try:
                cycled_device = await _power_cycle(device)
            except Exception as e:  # best effort: keep using the device as-is
                logger.debug(e)
                logger.info(f"Unable to power cycle {hex(device.idVendor)} {hex(device.idProduct)}")  # type: ignore
            else:
                if cycled_device is None:
                    raise TransportInitError('device not found after power cycle')
                if in_use_address is not None and cycled_device.address != in_use_address:
                    # The device re-enumerated with a new address: track that one
                    devices_in_use.discard(in_use_address)
                    devices_in_use.add(cycled_device.address)
                    in_use_address = cycled_device.address
                device = cycled_device

        # Collect the metadata
        device_metadata = {'vendor_id': device.idVendor, 'product_id': device.idProduct}

        # Detach the kernel driver if needed
        if device.is_kernel_driver_active(0):
            logger.debug("detaching kernel driver")
            device.detach_kernel_driver(0)

        # Set configuration, if needed
        try:
            configuration = device.get_active_configuration()
        except usb.core.USBError:
            device.set_configuration()
            configuration = device.get_active_configuration()
        interface = configuration[(0, 0)]
        logger.debug(f'USB Interface: {interface}')
        usb.util.claim_interface(device, 0)

        # Select an alternate setting for SCO, if available
        sco_enabled = False
        # pylint: disable=line-too-long
        # NOTE: this is disabled for now, because SCO with alternate settings is broken,
        # see: https://github.com/libusb/libusb/issues/36
        #
        # best_packet_size = 0
        # best_interface = None
        # sco_enabled = False
        # for interface in configuration:
        #     iso_in_endpoint = None
        #     iso_out_endpoint = None
        #     for endpoint in interface:
        #         if (endpoint.bEndpointAddress == USB_ENDPOINT_SCO_IN and
        #             usb.util.endpoint_direction(endpoint.bEndpointAddress) == usb.util.ENDPOINT_IN and
        #             usb.util.endpoint_type(endpoint.bmAttributes) == usb.util.ENDPOINT_TYPE_ISO):
        #             iso_in_endpoint = endpoint
        #             continue
        #         if (endpoint.bEndpointAddress == USB_ENDPOINT_SCO_OUT and
        #             usb.util.endpoint_direction(endpoint.bEndpointAddress) == usb.util.ENDPOINT_OUT and
        #             usb.util.endpoint_type(endpoint.bmAttributes) == usb.util.ENDPOINT_TYPE_ISO):
        #             iso_out_endpoint = endpoint

        #     if iso_in_endpoint is not None and iso_out_endpoint is not None:
        #         if iso_out_endpoint.wMaxPacketSize > best_packet_size:
        #             best_packet_size = iso_out_endpoint.wMaxPacketSize
        #             best_interface = interface

        # if best_interface is not None:
        #     logger.debug(f'SCO enabled, selecting alternate setting (wMaxPacketSize={best_packet_size}): {best_interface}')
        #     sco_enabled = True
        #     try:
        #         device.set_interface_altsetting(
        #             interface = best_interface.bInterfaceNumber,
        #             alternate_setting = best_interface.bAlternateSetting
        #         )
        #     except usb.USBError:
        #         logger.warning('failed to set alternate setting')

        packet_source = UsbPacketSource(device, device_metadata, sco_enabled)
        packet_sink = UsbPacketSink(device)
        packet_source.start()
        packet_sink.start()

        return UsbTransport(device, packet_source, packet_sink, in_use_address)
    except BaseException:
        # Don't leave the device marked as in use if opening it failed
        if in_use_address is not None:
            devices_in_use.discard(in_use_address)
        raise


async def _power_cycle(device: UsbDevice) -> Optional[UsbDevice]:
    """
    For devices connected to compatible USB hubs: Performs a power cycle on a given USB device.
    This involves temporarily disabling its port on the hub and then re-enabling it.

    Returns:
      The device found again at the same port path after the power cycle (None if
      nothing is found there), or the original device unchanged when it has no
      port numbers, its parent is not a hub, or a USB error occurred.
    """
    port_numbers = device.port_numbers  # type: ignore
    if not port_numbers:
        logger.debug('device reports no port numbers, cannot power cycle it')
        return device

    device_path = _device_path(device)
    hub = _find_hub_by_device_path(device_path)
    if hub is None:
        return device

    device_port = port_numbers[-1]
    try:
        _set_port_status(hub, device_port, False)
        await asyncio.sleep(POWER_CYCLE_DELAY)
        _set_port_status(hub, device_port, True)
        await asyncio.sleep(RESET_DELAY)

        # Device needs to be found again otherwise it will appear as disconnected.
        # Look it up at the same port path so that another device with the same
        # vendor/product IDs elsewhere is never picked instead.
        return _find_device_by_path(device_path)
    except USBError as e:
        logger.error(f"Adjustment needed: Please revise the udev rule for device {hex(device.idVendor)}:{hex(device.idProduct)} for proper recognition.")  # type: ignore
        logger.error(e)

    return device


def _set_port_status(device: UsbDevice, port: int, on: bool):
    """
    Sets the power status of a specific port on a USB hub.

    Sends a class SET_FEATURE (on=True) or CLEAR_FEATURE (on=False) request with
    the port power feature selector to `port` of the hub `device`.
    """
    device.ctrl_transfer(
        bmRequestType=CTRL_TYPE_CLASS | CTRL_RECIPIENT_OTHER,
        bRequest=REQ_SET_FEATURE if on else REQ_CLEAR_FEATURE,
        wIndex=port,
        wValue=USB_PORT_FEATURE_POWER,
    )


def _find_device_by_path(sys_path: str) -> Optional[UsbDevice]:
    """
    Finds a USB device based on its system path.

    `sys_path` has the `<bus>-<port>[.<port>...]` form; a bare `<bus>` matches a
    device of that bus reporting no port numbers. Returns None if nothing matches.
    """
    bus_text, _, ports_text = sys_path.partition('-')
    bus = int(bus_text)
    ports = [int(port) for port in ports_text.split('.')] if ports_text else []
    for device in _usb_find_function()(find_all=True, bus=bus):
        # port_numbers may be None, so normalize it before comparing
        if list(device.port_numbers or ()) == ports:  # type: ignore
            return device

    return None


def _find_hub_by_device_path(sys_path: str) -> Optional[UsbDevice]:
    """
    Finds the USB hub associated with a specific device path.

    Returns the device one level up in the port chain when it is a hub, and None
    otherwise, including when the path has no parent hub component (no '.').
    """
    if '.' not in sys_path:
        # Without a '.', stripping the last component would yield the device's
        # own path, not its hub's.
        return None
    hub_sys_path = sys_path.rsplit('.', 1)[0]
    hub_device = _find_device_by_path(hub_sys_path)

    if hub_device is None or not _is_hub(hub_device):
        return None
    return hub_device


def _is_hub(device: UsbDevice) -> bool:
    """Checks if a USB device is a hub, by device class or by any interface class."""
    if device.bDeviceClass == CLASS_HUB:  # type: ignore
        return True
    return any(
        interface.bInterfaceClass == CLASS_HUB  # type: ignore
        for config in device
        for interface in config
    )