1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import logging
20import threading
21import time
22
23import usb.core
24import usb.util
25
26from typing import Optional, Set
27from usb.core import Device as UsbDevice
28from usb.core import USBError
29from usb.util import CTRL_TYPE_CLASS, CTRL_RECIPIENT_OTHER
30from usb.legacy import REQ_SET_FEATURE, REQ_CLEAR_FEATURE, CLASS_HUB
31
32from .common import Transport, ParserSource, TransportInitError
33from .. import hci
34from ..colors import color
35
36
37# -----------------------------------------------------------------------------
38# Constant
39# -----------------------------------------------------------------------------
40USB_PORT_FEATURE_POWER = 8
41POWER_CYCLE_DELAY = 1
42RESET_DELAY = 3
43
44# -----------------------------------------------------------------------------
45# Logging
46# -----------------------------------------------------------------------------
47logger = logging.getLogger(__name__)
48
49# -----------------------------------------------------------------------------
50# Global
51# -----------------------------------------------------------------------------
52devices_in_use: Set[int] = set()
53
54
55# -----------------------------------------------------------------------------
async def open_pyusb_transport(spec: str) -> Transport:
    '''
    Open a USB transport. [Implementation based on PyUSB]
    The parameter string has this syntax:
    [!]<index> or [!]<vendor>:<product> or [!]<bus>-<port>[.<port>...]
    With <index> as the 0-based index to select amongst all the devices that appear
    to be supporting Bluetooth HCI (0 being the first one), or
    Where <vendor> and <product> are the vendor ID and product ID in hexadecimal, or
    Where <bus>-<port>[.<port>...] is the bus number followed by the device's
    port path (pyusb `port_numbers`, joined with '.'); a bare <bus> selects the
    device on that bus that reports no port path.
    A leading '!' asks for the device to be power cycled (its port on the parent
    hub is switched off and on) before it is opened. If the power cycle cannot be
    performed, this is logged and the device is opened anyway.

    Examples:
    0 --> the first BT USB dongle
    04b4:f901 --> the BT USB dongle with vendor=04b4 and product=f901
    1-2.3 --> the USB device on bus 1 whose port path is 2.3
    !04b4:f901 --> like 04b4:f901, but power cycled before being opened

    Raises:
        TransportInitError: if no device matches the spec, or if the device
            cannot be found again after being power cycled.
        ValueError: if the spec is malformed, or if no unused device matches
            <vendor>:<product>.
    '''

    # pylint: disable=invalid-name
    USB_RECIPIENT_DEVICE = 0x00
    USB_REQUEST_TYPE_CLASS = 0x01 << 5
    USB_ENDPOINT_EVENTS_IN = 0x81
    USB_ENDPOINT_ACL_IN = 0x82
    USB_ENDPOINT_SCO_IN = 0x83
    USB_ENDPOINT_ACL_OUT = 0x02
    #  USB_ENDPOINT_SCO_OUT                             = 0x03
    USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
    USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
    USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01

    READ_SIZE = 1024
    READ_TIMEOUT = 1000  # milliseconds (pyusb read timeout)

    class UsbPacketSink:
        '''
        Writes outgoing HCI packets to the device.

        ACL data goes to the bulk OUT endpoint and commands go out as
        class-specific control transfers. Writes currently happen synchronously
        inside on_packet. The background thread only waits until stop() is
        requested and then signals completion back to the event loop.
        '''

        def __init__(self, device):
            self.device = device
            self.thread = threading.Thread(target=self.run)
            self.loop = asyncio.get_running_loop()
            self.stop_event = None

        def on_packet(self, packet):
            # TODO: don't block here, just queue for the write thread
            if len(packet) == 0:
                logger.warning('packet too short')
                return

            packet_type = packet[0]
            try:
                if packet_type == hci.HCI_ACL_DATA_PACKET:
                    self.device.write(USB_ENDPOINT_ACL_OUT, packet[1:])
                elif packet_type == hci.HCI_COMMAND_PACKET:
                    self.device.ctrl_transfer(
                        USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
                        0,
                        0,
                        0,
                        packet[1:],
                    )
                else:
                    logger.warning(
                        color(f'unsupported packet type {packet_type}', 'red')
                    )
            except usb.core.USBTimeoutError:
                logger.warning('USB Write Timeout')
            except usb.core.USBError as error:
                logger.warning(f'USB write error: {error}')
                time.sleep(1)  # Sleep one second to avoid busy looping

        def start(self):
            self.thread.start()

        async def stop(self):
            # Create stop events and wait for them to be signaled
            self.stop_event = asyncio.Event()
            await self.stop_event.wait()

        def run(self):
            # Poll until stop() has created the event, then signal it on the loop
            while self.stop_event is None:
                time.sleep(1)
            self.loop.call_soon_threadsafe(self.stop_event.set)

    class UsbPacketSource(asyncio.Protocol, ParserSource):
        '''
        Reads incoming HCI packets from the device.

        One thread per IN endpoint (events, ACL, and SCO when enabled) reads
        blocks of data and prefixes each with its HCI packet type. Each packet is
        handed to the event loop with call_soon_threadsafe, and a dequeue task
        feeds the packets to the parser in order.
        '''

        def __init__(self, device, metadata, sco_enabled):
            super().__init__()
            self.device = device
            self.metadata = metadata
            self.loop = asyncio.get_running_loop()
            self.queue = asyncio.Queue()
            self.dequeue_task = None
            self.event_thread = threading.Thread(
                target=self.run, args=(USB_ENDPOINT_EVENTS_IN, hci.HCI_EVENT_PACKET)
            )
            self.event_thread.stop_event = None
            self.acl_thread = threading.Thread(
                target=self.run, args=(USB_ENDPOINT_ACL_IN, hci.HCI_ACL_DATA_PACKET)
            )
            self.acl_thread.stop_event = None

            # SCO support is optional
            self.sco_enabled = sco_enabled
            if sco_enabled:
                self.sco_thread = threading.Thread(
                    target=self.run,
                    args=(USB_ENDPOINT_SCO_IN, hci.HCI_SYNCHRONOUS_DATA_PACKET),
                )
                self.sco_thread.stop_event = None

        def data_received(self, data):
            self.parser.feed_data(data)

        def enqueue(self, packet):
            self.queue.put_nowait(packet)

        async def dequeue(self):
            while True:
                try:
                    data = await self.queue.get()
                except asyncio.CancelledError:
                    return
                self.data_received(data)

        def start(self):
            self.dequeue_task = self.loop.create_task(self.dequeue())
            self.event_thread.start()
            self.acl_thread.start()
            if self.sco_enabled:
                self.sco_thread.start()

        async def stop(self):
            # Stop the dequeuing task
            self.dequeue_task.cancel()

            # Create stop events and wait for them to be signaled. Every reader
            # thread that was started needs its own event, otherwise the wait
            # below would be on None.
            self.event_thread.stop_event = asyncio.Event()
            self.acl_thread.stop_event = asyncio.Event()
            if self.sco_enabled:
                self.sco_thread.stop_event = asyncio.Event()
            await self.event_thread.stop_event.wait()
            await self.acl_thread.stop_event.wait()
            if self.sco_enabled:
                await self.sco_thread.stop_event.wait()

        def run(self, endpoint, packet_type):
            # Read until asked to stop
            current_thread = threading.current_thread()
            while current_thread.stop_event is None:
                try:
                    # Read, with a timeout of 1 second
                    data = self.device.read(endpoint, READ_SIZE, timeout=READ_TIMEOUT)
                    packet = bytes([packet_type]) + data.tobytes()
                    self.loop.call_soon_threadsafe(self.enqueue, packet)
                except usb.core.USBTimeoutError:
                    continue
                except usb.core.USBError:
                    # Don't log this: because pyusb doesn't really support multiple
                    # threads reading at the same time, we can get occasional
                    # USBError(errno=5) Input/Output errors reported, but they seem to
                    # be harmless.
                    # Until support for async or multi-thread support is added to pyusb,
                    # we'll just live with this as is...
                    # logger.warning(f'USB read error: {error}')
                    time.sleep(1)  # Sleep one second to avoid busy looping

            stop_event = current_thread.stop_event
            self.loop.call_soon_threadsafe(stop_event.set)

    class UsbTransport(Transport):
        def __init__(self, device, source, sink, reserved_address=None):
            super().__init__(source, sink)
            self.device = device
            # Address this transport holds in devices_in_use, if any
            self.reserved_address = reserved_address

        async def close(self):
            try:
                await self.source.stop()
                await self.sink.stop()
                usb.util.release_interface(self.device, 0)
            finally:
                # Only release the reservation this transport made (if any)
                if self.reserved_address is not None:
                    devices_in_use.discard(self.reserved_address)
                    self.reserved_address = None

    usb_find = usb.core.find
    try:
        import libusb_package
    except ImportError:
        logger.debug('libusb_package is not available')
    else:
        usb_find = libusb_package.find

    # A leading '!' requests a power cycle before opening
    power_cycle = spec.startswith('!')
    if power_cycle:
        spec = spec[1:]

    # Address reserved in devices_in_use by this call (only for <vendor>:<product>).
    # It is released when the transport closes, or right away if setup fails.
    reserved_address: Optional[int] = None

    # Find the device according to the spec moniker
    if ':' in spec:
        vendor_id, product_id = spec.split(':')
        device = None
        devices = usb_find(
            find_all=True, idVendor=int(vendor_id, 16), idProduct=int(product_id, 16)
        )
        for candidate in devices:
            if candidate.address in devices_in_use:
                continue
            device = candidate
            break
        if device is None:
            raise ValueError('device already in use')
        reserved_address = device.address
        devices_in_use.add(reserved_address)
    elif '-' in spec:

        def device_path(candidate):
            if candidate.port_numbers:
                return f'{candidate.bus}-{".".join(map(str, candidate.port_numbers))}'
            else:
                return str(candidate.bus)

        device = usb_find(custom_match=lambda candidate: device_path(candidate) == spec)
    else:
        device_index = int(spec)
        devices = list(
            usb_find(
                find_all=True,
                bDeviceClass=USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
                bDeviceSubClass=USB_DEVICE_SUBCLASS_RF_CONTROLLER,
                bDeviceProtocol=USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
            )
        )
        device = devices[device_index] if len(devices) > device_index else None

    if device is None:
        raise TransportInitError('device not found')
    logger.debug(f'USB Device: {device}')

    try:
        # Power Cycle the device
        if power_cycle:
            try:
                cycled_device = await _power_cycle(device)  # type: ignore
            except Exception as e:
                logger.debug(e)
                logger.info(f"Unable to power cycle {hex(device.idVendor)} {hex(device.idProduct)}")  # type: ignore
            else:
                if cycled_device is None:
                    raise TransportInitError('device not found after power cycle')
                device = cycled_device
                if reserved_address is not None:
                    # The power-cycled device is a new handle, possibly with a
                    # new address: move the reservation over to it.
                    devices_in_use.discard(reserved_address)
                    reserved_address = device.address
                    devices_in_use.add(reserved_address)

        # Collect the metadata
        device_metadata = {'vendor_id': device.idVendor, 'product_id': device.idProduct}

        # Detach the kernel driver if needed
        try:
            kernel_driver_active = device.is_kernel_driver_active(0)
        except NotImplementedError:
            # pyusb raises this when the backend cannot query kernel driver
            # state (libusb reports LIBUSB_ERROR_NOT_SUPPORTED on platforms
            # without the feature); treat it as nothing to detach.
            kernel_driver_active = False
        if kernel_driver_active:
            logger.debug("detaching kernel driver")
            device.detach_kernel_driver(0)

        # Set configuration, if needed
        try:
            configuration = device.get_active_configuration()
        except usb.core.USBError:
            device.set_configuration()
            configuration = device.get_active_configuration()
        interface = configuration[(0, 0)]
        logger.debug(f'USB Interface: {interface}')
        usb.util.claim_interface(device, 0)
    except BaseException:
        # Don't leave the device marked as in use if it could not be set up
        if reserved_address is not None:
            devices_in_use.discard(reserved_address)
        raise

    # Select an alternate setting for SCO, if available
    sco_enabled = False
    # pylint: disable=line-too-long
    # NOTE: this is disabled for now, because SCO with alternate settings is broken,
    # see: https://github.com/libusb/libusb/issues/36
    #
    # best_packet_size = 0
    # best_interface = None
    # sco_enabled = False
    # for interface in configuration:
    #     iso_in_endpoint = None
    #     iso_out_endpoint = None
    #     for endpoint in interface:
    #         if (endpoint.bEndpointAddress == USB_ENDPOINT_SCO_IN and
    #             usb.util.endpoint_direction(endpoint.bEndpointAddress) == usb.util.ENDPOINT_IN and
    #             usb.util.endpoint_type(endpoint.bmAttributes) == usb.util.ENDPOINT_TYPE_ISO):
    #             iso_in_endpoint = endpoint
    #             continue
    #         if (endpoint.bEndpointAddress == USB_ENDPOINT_SCO_OUT and
    #             usb.util.endpoint_direction(endpoint.bEndpointAddress) == usb.util.ENDPOINT_OUT and
    #             usb.util.endpoint_type(endpoint.bmAttributes) == usb.util.ENDPOINT_TYPE_ISO):
    #             iso_out_endpoint = endpoint

    #     if iso_in_endpoint is not None and iso_out_endpoint is not None:
    #         if iso_out_endpoint.wMaxPacketSize > best_packet_size:
    #             best_packet_size = iso_out_endpoint.wMaxPacketSize
    #             best_interface = interface

    # if best_interface is not None:
    #     logger.debug(f'SCO enabled, selecting alternate setting (wMaxPacketSize={best_packet_size}): {best_interface}')
    #     sco_enabled = True
    #     try:
    #         device.set_interface_altsetting(
    #             interface = best_interface.bInterfaceNumber,
    #             alternate_setting = best_interface.bAlternateSetting
    #         )
    #     except usb.USBError:
    #         logger.warning('failed to set alternate setting')

    packet_source = UsbPacketSource(device, device_metadata, sco_enabled)
    packet_sink = UsbPacketSink(device)
    packet_source.start()
    packet_sink.start()

    return UsbTransport(device, packet_source, packet_sink, reserved_address)
354
355
async def _power_cycle(device: UsbDevice) -> UsbDevice:
    """
    For devices connected to compatible USB hubs: Performs a power cycle on a given USB device.

    The device's port on its parent hub is switched off, left off for
    POWER_CYCLE_DELAY seconds, switched back on, and then RESET_DELAY seconds
    are allowed before the device is looked up again.

    Returns:
        A fresh handle for the device found again on the same bus and port path,
        because the old handle refers to a device that was disconnected. If the
        device cannot be power cycled, the original `device` is returned
        unchanged. That covers three cases: it reports no port path, no parent
        hub is found, or the hub control transfer fails (which is logged).

    Raises:
        USBError: if the port was power cycled but the device could not be
            found again on the same port.
    """
    port_numbers = device.port_numbers  # type: ignore
    if not port_numbers:
        # No port path is known (pyusb reports None): there is no hub port to toggle.
        return device

    bus = device.bus  # type: ignore
    ports = tuple(port_numbers)
    device_path = f'{bus}-{".".join(map(str, ports))}'
    hub = _find_hub_by_device_path(device_path)
    if hub is None:
        return device

    vendor_id = device.idVendor  # type: ignore
    product_id = device.idProduct  # type: ignore
    device_port = ports[-1]
    try:
        _set_port_status(hub, device_port, False)
        await asyncio.sleep(POWER_CYCLE_DELAY)
        _set_port_status(hub, device_port, True)
        await asyncio.sleep(RESET_DELAY)
    except USBError as e:
        logger.error(f"Adjustment needed: Please revise the udev rule for device {hex(device.idVendor)}:{hex(device.idProduct)} for proper recognition.")  # type: ignore
        logger.error(e)
        return device

    def is_same_port(candidate) -> bool:
        return candidate.bus == bus and tuple(candidate.port_numbers or ()) == ports

    # Device needs to be found again, otherwise it will appear as disconnected.
    # Match on the port too, so another device with the same vendor/product IDs
    # (e.g. a second identical dongle) is not picked up by mistake.
    new_device = usb.core.find(
        idVendor=vendor_id, idProduct=product_id, custom_match=is_same_port
    )
    if new_device is None:
        raise USBError(
            f'device {vendor_id:04x}:{product_id:04x} not found on {device_path} after power cycle'
        )
    return new_device
379
380
def _set_port_status(device: UsbDevice, port: int, on: bool):
    """
    Switch the power of one port of a USB hub on or off.

    Sends a class-specific control request addressed to "other" (the port):
    SET_FEATURE when `on` is true, CLEAR_FEATURE otherwise, with the port power
    feature selector as wValue and `port` as wIndex.

    Args:
        device: the hub owning the port.
        port: the hub port number (as it appears last in a child's port_numbers).
        on: True to power the port, False to cut its power.
    """
    request_type = CTRL_TYPE_CLASS | CTRL_RECIPIENT_OTHER
    if on:
        request = REQ_SET_FEATURE
    else:
        request = REQ_CLEAR_FEATURE
    device.ctrl_transfer(
        bmRequestType=request_type,
        bRequest=request,
        wValue=USB_PORT_FEATURE_POWER,
        wIndex=port,
    )
389
390
def _find_device_by_path(sys_path: str) -> Optional[UsbDevice]:
    """
    Find a USB device from its path string.

    Args:
        sys_path: '<bus>-<port>[.<port>...]' (e.g. '1-2.4'), the same format
            _power_cycle builds from `device.port_numbers`, or just '<bus>' for
            a device on that bus with no port path (e.g. a root hub).

    Returns:
        The first device on the bus whose port path matches, or None.
    """
    bus_part, _, port_part = sys_path.partition('-')
    bus_number = int(bus_part)
    ports = tuple(int(port) for port in port_part.split('.')) if port_part else ()
    for device in usb.core.find(find_all=True, bus=bus_number):
        # pyusb reports port_numbers as None when no port path is known;
        # treat that as an empty path instead of failing on list(None).
        device_ports = tuple(device.port_numbers or ())  # type: ignore
        if device.bus == bus_number and device_ports == ports:  # type: ignore
            return device

    return None
402
403
def _find_hub_by_device_path(sys_path: str) -> Optional[UsbDevice]:
    """
    Find the hub a device is attached to, from the device's path string.

    The parent path is the device path minus its last port
    (e.g. '1-2.4' -> '1-2'). Returns None in three cases: the path has no hub
    port above the device ('<bus>-<port>' or just '<bus>'), the parent cannot
    be found, or the parent is not a hub.
    """
    _, separator, port_part = sys_path.partition('-')
    if not separator or '.' not in port_part:
        # No intermediate hub in the path. Stripping a '.'-suffix would leave
        # the device's own path, and the device would then be treated as its
        # own parent hub.
        return None

    hub_sys_path = sys_path.rsplit('.', 1)[0]
    hub_device = _find_device_by_path(hub_sys_path)
    if hub_device is None or not _is_hub(hub_device):
        return None
    return hub_device
413
414
def _is_hub(device: UsbDevice) -> bool:
    """
    Tell whether a USB device is a hub.

    A device counts as a hub when its device class is CLASS_HUB, or when any
    interface of any of its configurations has interface class CLASS_HUB.
    """
    if device.bDeviceClass == CLASS_HUB:  # type: ignore
        return True
    return any(
        interface.bInterfaceClass == CLASS_HUB  # type: ignore
        for configuration in device
        for interface in configuration
    )
424