1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19import asyncio
20import logging
21import threading
22import ctypes
23import platform
24
25import usb1
26
27from bumble.transport.common import Transport, BaseSource, TransportInitError
28from bumble import hci
29from bumble.colors import color
30
31
32# -----------------------------------------------------------------------------
33# Logging
34# -----------------------------------------------------------------------------
35logger = logging.getLogger(__name__)
36
37
38# -----------------------------------------------------------------------------
def load_libusb():
    '''
    Load the libusb-1.0 shared library shipped by the optional `libusb_package`
    module (if installed) and hand it to the usb1 backend.

    This has to happen before a usb1.USBContext is created; doing it once is enough.
    When `libusb_package` cannot be imported, or it reports no library path, nothing
    is loaded here and usb1 searches the default system paths when the context is
    created.
    '''
    try:
        import libusb_package
    except ImportError:
        logger.debug('libusb_package is not available')
        return

    libusb_path = libusb_package.get_library_path()
    if not libusb_path:
        # Nothing bundled: let usb1 fall back to its default search.
        return

    logger.debug(f'loading libusb library at {libusb_path}')

    # ctypes.WinDLL only exists on Windows, so pick the loader lazily.
    if platform.system() == 'Windows':
        loader = ctypes.WinDLL
    else:
        loader = ctypes.CDLL

    library = loader(str(libusb_path), use_errno=True, use_last_error=True)
    usb1.loadLibrary(library)
61
62
async def open_usb_transport(spec: str) -> Transport:
    '''
    Open a USB transport.
    The moniker string has this syntax:
    either <index> or
    <vendor>:<product> or
    <vendor>:<product>/<serial-number> or
    <vendor>:<product>#<index> or
    <bus>-<port>[.<port>...]
    With <index> as the 0-based index to select amongst all the devices that appear
    to be supporting Bluetooth HCI (0 being the first one), or
    Where <vendor> and <product> are the vendor ID and product ID in hexadecimal. The
    /<serial-number> suffix or #<index> suffix may be specified when more than one
    device with the same vendor and product identifiers are present.
    The <bus>-<port>[.<port>...] form selects the device whose USB bus number and
    port path match exactly (e.g. 1-2.3).

    In addition, if the moniker ends with the symbol "!", the device will be used in
    "forced" mode:
    the first USB interface of the device that exposes the expected endpoints (a bulk
    IN, a bulk OUT and an interrupt IN endpoint) will be used, regardless of the
    interface class/subclass/protocol.
    This may be useful for some devices that use a custom class/subclass but may
    nonetheless work as-is.

    Examples:
    0 --> the first BT USB dongle
    04b4:f901 --> the BT USB dongle with vendor=04b4 and product=f901
    04b4:f901#2 --> the third USB device with vendor=04b4 and product=f901
    04b4:f901/00E04C239987 --> the BT USB dongle with vendor=04b4 and product=f901 and
    serial number 00E04C239987
    0B05:17CB! --> the BT USB dongle vendor=0B05 and product=17CB, in "forced" mode.
    1-2.3 --> the USB device on bus 1 with port path 2.3

    Raises:
        TransportInitError: no matching device was found, or the device has no
            compatible interface.
        usb1.USBError: a libusb operation failed while opening the device.
        ValueError: the moniker is malformed (e.g. non-hexadecimal IDs).
    '''

    # pylint: disable=invalid-name
    USB_RECIPIENT_DEVICE = 0x00
    USB_REQUEST_TYPE_CLASS = 0x01 << 5
    USB_DEVICE_CLASS_DEVICE = 0x00
    USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
    USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
    USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01
    USB_ENDPOINT_TRANSFER_TYPE_BULK = 0x02
    USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT = 0x03
    USB_ENDPOINT_IN = 0x80

    USB_BT_HCI_CLASS_TUPLE = (
        USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
        USB_DEVICE_SUBCLASS_RF_CONTROLLER,
        USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
    )

    READ_SIZE = 4096

    class UsbPacketSink:
        '''
        Sends HCI packets to the device: ACL data over the bulk OUT endpoint and HCI
        commands as class-specific control transfers to the device. A single transfer
        object is reused, so at most one OUT transfer is in flight; other packets wait
        in a queue.
        '''

        def __init__(self, device, acl_out):
            self.device = device
            self.acl_out = acl_out
            self.acl_out_transfer = device.getTransfer()
            # Held while acl_out_transfer is submitted; released by the callback.
            self.acl_out_transfer_ready = asyncio.Semaphore(1)
            # Queue of packets waiting to be sent
            self.packets: asyncio.Queue[bytes] = asyncio.Queue()
            self.loop = asyncio.get_running_loop()
            self.queue_task = None
            # Resolved once the OUT transfer has ended after a cancellation/close.
            self.cancel_done = self.loop.create_future()
            self.closed = False

        def start(self):
            self.queue_task = asyncio.create_task(self.process_queue())

        def on_packet(self, packet):
            # Ignore packets if we're closed
            if self.closed:
                return

            if len(packet) == 0:
                logger.warning('packet too short')
                return

            # Queue the packet
            self.packets.put_nowait(packet)

        def _resolve_cancel_done(self):
            # Runs on the asyncio loop; idempotent so racing callbacks are harmless.
            if not self.cancel_done.done():
                self.cancel_done.set_result(None)

        def transfer_callback(self, transfer):
            # Invoked from the USB event thread (context.handleEvents), so asyncio
            # objects are only touched through call_soon_threadsafe.
            self.loop.call_soon_threadsafe(self.acl_out_transfer_ready.release)
            status = transfer.getStatus()

            # pylint: disable=no-member
            if status == usb1.TRANSFER_CANCELLED:
                self.loop.call_soon_threadsafe(self._resolve_cancel_done)
                return

            if self.closed:
                # A cancel() issued by terminate() can race with normal completion;
                # the transfer is finished either way, so unblock terminate().
                self.loop.call_soon_threadsafe(self._resolve_cancel_done)

            if status != usb1.TRANSFER_COMPLETED:
                logger.warning(
                    color(f'!!! OUT transfer not completed: status={status}', 'red')
                )

        async def process_queue(self):
            while True:
                # Wait for a packet to transfer.
                packet = await self.packets.get()

                # Reject unsupported packets *before* taking the transfer, otherwise
                # the semaphore would never be released and the sink would stall.
                packet_type = packet[0]
                if packet_type not in (
                    hci.HCI_ACL_DATA_PACKET,
                    hci.HCI_COMMAND_PACKET,
                ):
                    logger.warning(
                        color(f'unsupported packet type {packet_type}', 'red')
                    )
                    continue

                # Wait until we can start a transfer.
                await self.acl_out_transfer_ready.acquire()

                # Transfer the packet.
                if packet_type == hci.HCI_ACL_DATA_PACKET:
                    self.acl_out_transfer.setBulk(
                        self.acl_out, packet[1:], callback=self.transfer_callback
                    )
                else:
                    self.acl_out_transfer.setControl(
                        USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
                        0,
                        0,
                        0,
                        packet[1:],
                        callback=self.transfer_callback,
                    )

                try:
                    self.acl_out_transfer.submit()
                except usb1.USBError as error:
                    # No callback will fire for a transfer that was never submitted,
                    # so give the transfer back here and drop the packet.
                    logger.warning(
                        color(f'!!! OUT transfer submission failed: {error}', 'red')
                    )
                    self.acl_out_transfer_ready.release()

        def close(self):
            self.closed = True
            if self.queue_task:
                self.queue_task.cancel()

        async def terminate(self):
            if not self.closed:
                self.close()

            # Empty the packet queue so that we don't send any more data
            while not self.packets.empty():
                self.packets.get_nowait()

            # If we have a transfer in flight, cancel it
            if self.acl_out_transfer.isSubmitted():
                # Try to cancel the transfer, but that may fail because it may have
                # already completed
                try:
                    self.acl_out_transfer.cancel()

                    logger.debug('waiting for OUT transfer cancellation to be done...')
                    await self.cancel_done
                    logger.debug('OUT transfer cancellation done')
                except usb1.USBError:
                    logger.debug('OUT transfer likely already completed')

    class UsbPacketSource(asyncio.Protocol, BaseSource):
        '''
        Receives HCI packets from the device: events from the interrupt IN endpoint
        and ACL data from the bulk IN endpoint. Each received buffer is prefixed with
        its HCI packet type and forwarded to the sink from the asyncio loop.
        '''

        def __init__(self, device, metadata, acl_in, events_in):
            super().__init__()
            self.device = device
            self.metadata = metadata
            self.acl_in = acl_in
            self.acl_in_transfer = None
            self.events_in = events_in
            self.events_in_transfer = None
            self.loop = asyncio.get_running_loop()
            self.queue = asyncio.Queue()
            self.dequeue_task = None
            # Per packet type: resolved once that IN transfer has ended and will not
            # be re-submitted.
            self.cancel_done = {
                hci.HCI_EVENT_PACKET: self.loop.create_future(),
                hci.HCI_ACL_DATA_PACKET: self.loop.create_future(),
            }
            self.closed = False

        def start(self):
            # Set up transfer objects for input, on this source's own device handle.
            self.events_in_transfer = self.device.getTransfer()
            self.events_in_transfer.setInterrupt(
                self.events_in,
                READ_SIZE,
                callback=self.transfer_callback,
                user_data=hci.HCI_EVENT_PACKET,
            )
            self.events_in_transfer.submit()

            self.acl_in_transfer = self.device.getTransfer()
            self.acl_in_transfer.setBulk(
                self.acl_in,
                READ_SIZE,
                callback=self.transfer_callback,
                user_data=hci.HCI_ACL_DATA_PACKET,
            )
            self.acl_in_transfer.submit()

            self.dequeue_task = self.loop.create_task(self.dequeue())

        @property
        def usb_transfer_submitted(self):
            '''True while at least one IN transfer is submitted.'''
            return any(
                transfer is not None and transfer.isSubmitted()
                for transfer in (self.events_in_transfer, self.acl_in_transfer)
            )

        def _resolve_cancel_done(self, packet_type):
            # Runs on the asyncio loop; idempotent so racing callbacks are harmless.
            future = self.cancel_done[packet_type]
            if not future.done():
                future.set_result(None)

        def transfer_callback(self, transfer):
            # Invoked from the USB event thread (context.handleEvents), so asyncio
            # objects are only touched through call_soon_threadsafe.
            packet_type = transfer.getUserData()
            status = transfer.getStatus()

            # pylint: disable=no-member
            if status == usb1.TRANSFER_COMPLETED:
                packet = (
                    bytes([packet_type])
                    + transfer.getBuffer()[: transfer.getActualLength()]
                )
                self.loop.call_soon_threadsafe(self.queue.put_nowait, packet)

                # Re-submit the transfer so we can receive more data, unless we are
                # closing: a completion racing with terminate()'s cancel() must not
                # re-arm the transfer, or the cancellation would never finish.
                if not self.closed:
                    try:
                        transfer.submit()
                    except usb1.USBError as error:
                        logger.warning(
                            color(
                                f'!!! IN transfer re-submission failed: {error}',
                                'red',
                            )
                        )
                        self.loop.call_soon_threadsafe(self.on_transport_lost)
                    else:
                        return
            elif status != usb1.TRANSFER_CANCELLED:
                logger.warning(
                    color(f'!!! IN transfer not completed: status={status}', 'red')
                )
                self.loop.call_soon_threadsafe(self.on_transport_lost)

            # The transfer is no longer in flight: release anyone waiting on it.
            self.loop.call_soon_threadsafe(self._resolve_cancel_done, packet_type)

        async def dequeue(self):
            while not self.closed:
                try:
                    packet = await self.queue.get()
                except asyncio.CancelledError:
                    return
                if self.sink:
                    try:
                        self.sink.on_packet(packet)
                    except Exception as error:
                        logger.exception(
                            color(f'!!! Exception in sink.on_packet: {error}', 'red')
                        )

        def close(self):
            self.closed = True

        async def terminate(self):
            if not self.closed:
                self.close()

            if self.dequeue_task:
                self.dequeue_task.cancel()

            # Cancel the transfers
            for transfer in (self.events_in_transfer, self.acl_in_transfer):
                if transfer is None or not transfer.isSubmitted():
                    continue
                # Try to cancel the transfer, but that may fail because it may have
                # already completed
                packet_type = transfer.getUserData()
                try:
                    transfer.cancel()
                    logger.debug(
                        f'waiting for IN[{packet_type}] transfer cancellation '
                        'to be done...'
                    )
                    await self.cancel_done[packet_type]
                    logger.debug(f'IN[{packet_type}] transfer cancellation done')
                except usb1.USBError:
                    logger.debug(
                        f'IN[{packet_type}] transfer likely already completed'
                    )

    class UsbTransport(Transport):
        '''
        Transport over an opened USB device. Claims the HCI interface, starts the
        source and sink, and runs a background thread that calls
        context.handleEvents() to drive the asynchronous USB transfers.
        '''

        def __init__(self, context, device, interface, setting, source, sink):
            super().__init__(source, sink)
            self.context = context
            self.device = device
            self.interface = interface
            self.loop = asyncio.get_running_loop()
            self.event_loop_done = self.loop.create_future()

            # Get exclusive access
            device.claimInterface(interface)

            # Set the alternate setting if not the default
            if setting != 0:
                device.setInterfaceAltSetting(interface, setting)

            # The source and sink can now start
            source.start()
            sink.start()

            # Create a thread to process events
            self.event_thread = threading.Thread(target=self.run)
            self.event_thread.start()

        def _transfers_in_flight(self):
            # Both directions count: the sink's OUT cancellation during close() also
            # needs events to be handled, otherwise sink.terminate() waits forever.
            return (
                self.source.usb_transfer_submitted
                or self.sink.acl_out_transfer.isSubmitted()
            )

        def run(self):
            logger.debug('starting USB event loop')
            try:
                while self._transfers_in_flight():
                    # pylint: disable=no-member
                    try:
                        self.context.handleEvents()
                    except usb1.USBErrorInterrupted:
                        pass
            finally:
                # Always signal, even if handleEvents raised, so close() can finish.
                logger.debug('USB event loop done')
                self.loop.call_soon_threadsafe(self.event_loop_done.set_result, None)

        async def close(self):
            self.source.close()
            self.sink.close()
            await self.source.terminate()
            await self.sink.terminate()

            # Wait for the event thread to exit before tearing down the handle and
            # context it may still be using inside handleEvents().
            await self.event_loop_done

            try:
                self.device.releaseInterface(self.interface)
            except usb1.USBError as error:
                logger.debug(f'unable to release interface: {error}')
            finally:
                self.device.close()
                self.context.close()

    # Find the device according to the spec moniker
    load_libusb()
    context = usb1.USBContext()
    context.open()
    try:
        found = None

        if spec.endswith('!'):
            spec = spec[:-1]
            forced_mode = True
        else:
            forced_mode = False

        if ':' in spec:
            vendor_id_str, product_id_str = spec.split(':')
            serial_number = None
            device_index = 0
            if '/' in product_id_str:
                product_id_str, serial_number = product_id_str.split('/')
            elif '#' in product_id_str:
                product_id_str, device_index_str = product_id_str.split('#')
                device_index = int(device_index_str)

            # Parse once, before enumerating, so a malformed moniker fails fast.
            vendor_id = int(vendor_id_str, 16)
            product_id = int(product_id_str, 16)

            def read_serial_number(device):
                # Reading the serial number can fail with USBError; treat that as
                # "no serial number".
                try:
                    return device.getSerialNumber()
                except usb1.USBError:
                    return None

            for device in context.getDeviceIterator(skip_on_error=True):
                # The serial number is only read for devices whose IDs match.
                if (
                    device.getVendorID() == vendor_id
                    and device.getProductID() == product_id
                    and (
                        serial_number is None
                        or serial_number == read_serial_number(device)
                    )
                ):
                    if device_index == 0:
                        found = device
                        break
                    device_index -= 1
                device.close()
        elif '-' in spec:

            def device_path(device):
                return f'{device.getBusNumber()}-{".".join(map(str, device.getPortNumberList()))}'

            for device in context.getDeviceIterator(skip_on_error=True):
                if device_path(device) == spec:
                    found = device
                    break
                device.close()
        else:
            # Look for a compatible device by index
            def device_is_bluetooth_hci(device):
                # Check if the device class indicates a match
                if (
                    device.getDeviceClass(),
                    device.getDeviceSubClass(),
                    device.getDeviceProtocol(),
                ) == USB_BT_HCI_CLASS_TUPLE:
                    return True

                # If the device class is 'Device', look for a matching interface
                if device.getDeviceClass() == USB_DEVICE_CLASS_DEVICE:
                    for configuration in device:
                        for interface in configuration:
                            for setting in interface:
                                if (
                                    setting.getClass(),
                                    setting.getSubClass(),
                                    setting.getProtocol(),
                                ) == USB_BT_HCI_CLASS_TUPLE:
                                    return True

                return False

            device_index = int(spec)
            for device in context.getDeviceIterator(skip_on_error=True):
                if device_is_bluetooth_hci(device):
                    if device_index == 0:
                        found = device
                        break
                    device_index -= 1
                device.close()

        if found is None:
            # The context is closed by the exception handler below.
            raise TransportInitError('device not found')

        logger.debug(f'USB Device: {found}')

        # Look for the first interface with the right class and endpoints
        def find_endpoints(device):
            '''
            Return (configuration, interface, setting, acl_in, acl_out, events_in)
            for the first interface setting that has a bulk IN, a bulk OUT and an
            interrupt IN endpoint (and, unless in forced mode, the BT HCI class), or
            None if there is no such setting. The configuration value is 1-based.
            '''
            # pylint: disable-next=too-many-nested-blocks
            for configuration_index, configuration in enumerate(device):
                for interface in configuration:
                    for setting in interface:
                        if (
                            not forced_mode
                            and (
                                setting.getClass(),
                                setting.getSubClass(),
                                setting.getProtocol(),
                            )
                            != USB_BT_HCI_CLASS_TUPLE
                        ):
                            continue

                        events_in = None
                        acl_in = None
                        acl_out = None
                        for endpoint in setting:
                            transfer_type = endpoint.getAttributes() & 0x03
                            address = endpoint.getAddress()
                            is_in = bool(address & USB_ENDPOINT_IN)
                            if transfer_type == USB_ENDPOINT_TRANSFER_TYPE_BULK:
                                # Check the direction first so that an extra bulk IN
                                # endpoint is never mistaken for the OUT endpoint.
                                if is_in:
                                    if acl_in is None:
                                        acl_in = address
                                elif acl_out is None:
                                    acl_out = address
                            elif transfer_type == USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT:
                                if is_in and events_in is None:
                                    events_in = address

                        # Return if we found all 3 endpoints
                        if (
                            acl_in is not None
                            and acl_out is not None
                            and events_in is not None
                        ):
                            return (
                                configuration_index + 1,
                                setting.getNumber(),
                                setting.getAlternateSetting(),
                                acl_in,
                                acl_out,
                                events_in,
                            )

                        logger.debug(
                            f'skipping configuration {configuration_index + 1} / '
                            f'interface {setting.getNumber()}'
                        )

            return None

        endpoints = find_endpoints(found)
        if endpoints is None:
            raise TransportInitError('no compatible interface found for device')
        (configuration, interface, setting, acl_in, acl_out, events_in) = endpoints
        logger.debug(
            f'selected endpoints: configuration={configuration}, '
            f'interface={interface}, '
            f'setting={setting}, '
            f'acl_in=0x{acl_in:02X}, '
            f'acl_out=0x{acl_out:02X}, '
            f'events_in=0x{events_in:02X}, '
        )

        device_metadata = {
            'vendor_id': found.getVendorID(),
            'product_id': found.getProductID(),
        }
        device = found.open()

        # Auto-detach the kernel driver if supported
        # pylint: disable=no-member
        if usb1.hasCapability(usb1.CAP_SUPPORTS_DETACH_KERNEL_DRIVER):
            try:
                logger.debug('auto-detaching kernel driver')
                device.setAutoDetachKernelDriver(True)
            except usb1.USBError as error:
                logger.warning(f'unable to auto-detach kernel driver: {error}')

        # Set the configuration if needed
        try:
            current_configuration = device.getConfiguration()
            logger.debug(f'current configuration = {current_configuration}')
        except usb1.USBError:
            current_configuration = 0

        if current_configuration != configuration:
            try:
                logger.debug(f'setting configuration {configuration}')
                device.setConfiguration(configuration)
            except usb1.USBError:
                logger.warning('failed to set configuration')

        source = UsbPacketSource(device, device_metadata, acl_in, events_in)
        sink = UsbPacketSink(device, acl_out)
        return UsbTransport(context, device, interface, setting, source, sink)
    except usb1.USBError as error:
        logger.warning(color(f'!!! failed to open USB device: {error}', 'red'))
        context.close()
        raise
    except BaseException:
        # Any other failure (device/interface not found, malformed moniker, ...)
        # must not leak the libusb context either.
        context.close()
        raise
568