1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import os
20import struct
21import logging
22import click
23
24from bumble import l2cap
25from bumble.colors import color
26from bumble.device import Device, Peer
27from bumble.core import AdvertisingData
28from bumble.gatt import Service, Characteristic, CharacteristicValue
29from bumble.utils import AsyncRunner
30from bumble.transport import open_transport_or_link
31from bumble.hci import HCI_Constant
32
33
34# -----------------------------------------------------------------------------
35# Constants
36# -----------------------------------------------------------------------------
# All Gattlink UUIDs share the same base; only the byte following 'ABBAFF'
# differs (00 = service, 01 = RX, 02 = TX, 03 = L2CAP PSM).
_GG_GATTLINK_UUID_TEMPLATE = 'ABBAFF{:02X}-E56A-484C-B832-8B17CF6CBFE8'

GG_GATTLINK_SERVICE_UUID = _GG_GATTLINK_UUID_TEMPLATE.format(0x00)
GG_GATTLINK_RX_CHARACTERISTIC_UUID = _GG_GATTLINK_UUID_TEMPLATE.format(0x01)
GG_GATTLINK_TX_CHARACTERISTIC_UUID = _GG_GATTLINK_UUID_TEMPLATE.format(0x02)
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID = (
    _GG_GATTLINK_UUID_TEMPLATE.format(0x03)
)

# ATT MTU requested by the hub after connecting
GG_PREFERRED_MTU = 256
45
46
47# -----------------------------------------------------------------------------
class GattlinkL2capEndpoint:
    """Reassemble Gattlink packets from the SDUs of an L2CAP CoC channel.

    On the L2CAP channel, each Gattlink packet is preceded by a one-byte
    header holding the packet size minus one, so packets are 1 to 256 bytes
    long. Packet boundaries are independent of SDU boundaries: a packet may
    span several SDUs, and one SDU may carry several packets (or only a
    header). Each completed packet is passed to ``on_l2cap_packet``, which
    subclasses must implement.
    """

    def __init__(self):
        # Open L2CAP channel, if any (assigned by subclasses)
        self.l2cap_channel = None
        # Bytes of the packet currently being reassembled
        self.l2cap_packet = b''
        # Expected size of that packet; 0 means the next byte is a header
        self.l2cap_packet_size = 0

    # Called when an L2CAP SDU has been received
    def on_coc_sdu(self, sdu):
        print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
        offset = 0
        while offset < len(sdu):
            if self.l2cap_packet_size == 0:
                # Expect a new packet: the header byte is (size - 1)
                self.l2cap_packet_size = sdu[offset] + 1
                offset += 1
                continue

            # Take as many payload bytes as the current packet still needs
            bytes_needed = self.l2cap_packet_size - len(self.l2cap_packet)
            chunk = sdu[offset : offset + bytes_needed]
            self.l2cap_packet += chunk
            offset += len(chunk)

            if len(self.l2cap_packet) == self.l2cap_packet_size:
                # Reset before dispatching so that an exception raised by the
                # handler cannot leave a completed packet buffered (which
                # would otherwise be re-delivered with the next SDU).
                packet = self.l2cap_packet
                self.l2cap_packet = b''
                self.l2cap_packet_size = 0
                self.on_l2cap_packet(packet)

    # Called when a complete Gattlink packet has been reassembled
    def on_l2cap_packet(self, packet):
        """Handle one reassembled packet (must be overridden by subclasses)."""
        raise NotImplementedError
71
72
73# -----------------------------------------------------------------------------
class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
    """Hub side of the Gattlink UDP bridge (GATT client, L2CAP CoC initiator).

    Connects to the node at ``peer_address`` and discovers its Gattlink
    service.

    * If the node exposes the L2CAP PSM characteristic, an L2CAP CoC channel
      is opened on that PSM.
    * Otherwise, the hub subscribes to the node's TX characteristic.

    UDP datagrams received locally go over the L2CAP channel when one is
    open, or else are written to the node's RX characteristic. Data coming
    from the node (L2CAP packets or TX notifications) is sent out through
    ``tx_socket``. Instances also act as the asyncio datagram protocol of the
    receiving UDP socket.
    """

    def __init__(self, device, peer_address):
        super().__init__()
        self.device = device
        self.peer_address = peer_address
        self.peer = None
        # UDP transport used to send data received from the peer (set by run())
        self.tx_socket = None
        self.rx_characteristic = None
        self.tx_characteristic = None
        self.l2cap_psm_characteristic = None
        # Strong references to fire-and-forget tasks, so that they cannot be
        # garbage-collected before they complete.
        self.background_tasks = set()

        device.listener = self

    def _run_in_background(self, coroutine):
        """Schedule ``coroutine`` as a task and keep it referenced until done."""
        task = asyncio.create_task(coroutine)
        self.background_tasks.add(task)
        task.add_done_callback(self.background_tasks.discard)
        return task

    def _reset_l2cap_state(self):
        """Forget the L2CAP channel and any partially reassembled packet."""
        self.l2cap_channel = None
        self.l2cap_packet = b''
        self.l2cap_packet_size = 0

    async def start(self):
        # Connect to the peer
        print(f'=== Connecting to {self.peer_address}...')
        await self.device.connect(self.peer_address)

    async def connect_l2cap(self, psm):
        """Open an L2CAP CoC channel to the peer on ``psm``; report failures."""
        print(color(f'### Connecting with L2CAP on PSM = {psm}', 'yellow'))
        if self.peer is None:
            print(color('!!! Not connected, cannot open L2CAP channel', 'red'))
            return

        try:
            channel = await self.peer.connection.open_l2cap_channel(psm)
        except Exception as error:
            # Reported but not fatal: the GATT RX path remains usable
            print(color(f'!!! Connection failed: {error}', 'red'))
            return

        print(color('*** Connected', 'yellow'), channel)
        # Start reassembly from a clean state for the new channel
        self._reset_l2cap_state()
        self.l2cap_channel = channel
        channel.sink = self.on_coc_sdu
        channel.on('close', lambda *_: self._on_l2cap_channel_closed(channel))

    def _on_l2cap_channel_closed(self, channel):
        # Only forget the channel if it is still the current one
        if self.l2cap_channel is channel:
            print(color('*** L2CAP channel closed', 'yellow'))
            self._reset_l2cap_state()

    @AsyncRunner.run_in_task()
    # pylint: disable=invalid-overridden-method
    async def on_connection(self, connection):
        """Discover the Gattlink service and set up the data path to the node."""
        print(f'=== Connected to {connection}')
        self.peer = Peer(connection)

        # Request a larger MTU than the default
        server_mtu = await self.peer.request_mtu(GG_PREFERRED_MTU)
        print(f'### Server MTU = {server_mtu}')

        # Discover the Gattlink service
        print(color('=== Discovering services', 'yellow'))
        await self.peer.discover_service(GG_GATTLINK_SERVICE_UUID)
        print(color('=== Services discovered', 'yellow'), self.peer.services)
        for service in self.peer.services:
            print(service)
        services = self.peer.get_services_by_uuid(GG_GATTLINK_SERVICE_UUID)
        if not services:
            print(color('!!! Gattlink service not found', 'red'))
            return

        # Use the first Gattlink service (there should only be one anyway)
        gattlink_service = services[0]

        # Discover all the characteristics for the service
        characteristics = await gattlink_service.discover_characteristics()
        print(color('=== Characteristics discovered', 'yellow'))
        for characteristic in characteristics:
            if characteristic.uuid == GG_GATTLINK_RX_CHARACTERISTIC_UUID:
                self.rx_characteristic = characteristic
            elif characteristic.uuid == GG_GATTLINK_TX_CHARACTERISTIC_UUID:
                self.tx_characteristic = characteristic
            elif (
                characteristic.uuid == GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID
            ):
                self.l2cap_psm_characteristic = characteristic
        print('RX:', self.rx_characteristic)
        print('TX:', self.tx_characteristic)
        print('PSM:', self.l2cap_psm_characteristic)

        if self.l2cap_psm_characteristic:
            # Subscribe to and then read the PSM value (16-bit little-endian)
            await self.peer.subscribe(
                self.l2cap_psm_characteristic, self.on_l2cap_psm_received
            )
            psm_bytes = await self.peer.read_value(self.l2cap_psm_characteristic)
            psm = struct.unpack('<H', psm_bytes)[0]
            await self.connect_l2cap(psm)
        elif self.tx_characteristic:
            # Subscribe to TX
            await self.peer.subscribe(self.tx_characteristic, self.on_tx_received)
            print(color('=== Subscribed to Gattlink TX', 'yellow'))
        else:
            print(color('!!! No Gattlink TX or PSM found', 'red'))

    def on_connection_failure(self, error):
        print(color(f'!!! Connection failed: {error}', 'red'))

    def on_disconnection(self, reason):
        print(
            color(
                f'!!! Disconnected from {self.peer}, '
                f'reason={HCI_Constant.error_name(reason)}',
                'red',
            )
        )
        # Drop everything tied to the old connection so that nothing stale is
        # used for later datagrams or reused on the next connection.
        self.tx_characteristic = None
        self.rx_characteristic = None
        self.l2cap_psm_characteristic = None
        self._reset_l2cap_state()
        self.peer = None

    # Called when an L2CAP packet has been received
    def on_l2cap_packet(self, packet):
        print(color(f'<<< [L2CAP PACKET]: {len(packet)} bytes', 'cyan'))
        if self.tx_socket:
            print(color('>>> [UDP]', 'magenta'))
            self.tx_socket.sendto(packet)

    # Called by the GATT client when a notification is received
    def on_tx_received(self, value):
        print(color(f'<<< [GATT TX]: {len(value)} bytes', 'cyan'))
        if self.tx_socket:
            print(color('>>> [UDP]', 'magenta'))
            self.tx_socket.sendto(value)

    # Called by the GATT client when the L2CAP PSM characteristic is notified
    def on_l2cap_psm_received(self, value):
        psm = struct.unpack('<H', value)[0]
        self._run_in_background(self.connect_l2cap(psm))

    # Called by asyncio when the UDP socket is created
    def connection_made(self, transport):
        pass

    # Called by asyncio when a send or receive operation on the UDP socket fails
    def error_received(self, error):
        print(color(f'!!! UDP error: {error}', 'red'))

    # Called by asyncio when the UDP socket is closed
    def connection_lost(self, error):
        pass

    # Called by asyncio when a UDP datagram is received
    def datagram_received(self, data, _address):
        print(color(f'<<< [UDP]: {len(data)} bytes', 'green'))

        if self.l2cap_channel:
            # The 1-byte header stores (size - 1), so only 1..256 byte
            # packets can be framed (bytes([...]) would raise otherwise).
            if not 1 <= len(data) <= 256:
                print(
                    color(
                        f'!!! Cannot frame a {len(data)}-byte datagram for '
                        'L2CAP, dropping it',
                        'red',
                    )
                )
                return
            print(color('>>> [L2CAP]', 'yellow'))
            self.l2cap_channel.write(bytes([len(data) - 1]) + data)
        elif self.peer and self.rx_characteristic:
            print(color('>>> [GATT RX]', 'yellow'))
            self._run_in_background(
                self.peer.write_value(self.rx_characteristic, data)
            )
205
206# -----------------------------------------------------------------------------
class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
    """Node side of the Gattlink UDP bridge (GATT server, L2CAP CoC listener).

    The node does the following:

    * hosts the Gattlink service (RX, TX and L2CAP PSM characteristics);
    * accepts L2CAP CoC connections on PSM 0xFB;
    * advertises itself as 'Bumble GG'.

    UDP datagrams received locally go to the hub over the L2CAP channel when
    one is open, or else as TX notifications once the hub has subscribed.
    Data coming from the hub (RX writes or L2CAP packets) is sent out through
    ``tx_socket``. Instances also act as the asyncio datagram protocol of the
    receiving UDP socket.
    """

    def __init__(self, device: Device):
        super().__init__()
        self.device = device
        self.peer = None
        # UDP transport used to send data received from the hub (set by run())
        self.tx_socket = None
        # Peer currently subscribed to TX notifications, if any
        self.tx_subscriber = None
        self.rx_characteristic = None
        # Receiving UDP transport (set by connection_made)
        self.transport = None
        # Strong references to fire-and-forget tasks, so that they cannot be
        # garbage-collected before they complete.
        self.background_tasks = set()

        # Register as a listener
        device.listener = self

        # Listen for incoming L2CAP CoC connections
        psm = 0xFB
        device.create_l2cap_server(
            spec=l2cap.LeCreditBasedChannelSpec(psm=psm),
            handler=self.on_coc,
        )
        print(f'### Listening for CoC connection on PSM {psm}')

        # Setup the Gattlink service
        self.rx_characteristic = Characteristic(
            GG_GATTLINK_RX_CHARACTERISTIC_UUID,
            Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
            Characteristic.WRITEABLE,
            CharacteristicValue(write=self.on_rx_write),
        )
        self.tx_characteristic = Characteristic(
            GG_GATTLINK_TX_CHARACTERISTIC_UUID,
            Characteristic.Properties.NOTIFY,
            Characteristic.READABLE,
        )
        self.tx_characteristic.on('subscription', self.on_tx_subscription)
        # The PSM is exposed as a 16-bit little-endian value, the format the
        # hub decodes with struct '<H'.
        self.psm_characteristic = Characteristic(
            GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID,
            Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
            Characteristic.READABLE,
            struct.pack('<H', psm),
        )
        gattlink_service = Service(
            GG_GATTLINK_SERVICE_UUID,
            [self.rx_characteristic, self.tx_characteristic, self.psm_characteristic],
        )
        device.add_services([gattlink_service])

        # The 128-bit service UUID is advertised with its bytes reversed,
        # derived from the same constant the service is registered with.
        advertised_service_uuid = bytes(
            reversed(bytes.fromhex(GG_GATTLINK_SERVICE_UUID.replace('-', '')))
        )
        device.advertising_data = bytes(
            AdvertisingData(
                [
                    (AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble GG', 'utf-8')),
                    (
                        AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
                        advertised_service_uuid,
                    ),
                ]
            )
        )

    def _run_in_background(self, coroutine):
        """Schedule ``coroutine`` as a task and keep it referenced until done."""
        task = asyncio.create_task(coroutine)
        self.background_tasks.add(task)
        task.add_done_callback(self.background_tasks.discard)
        return task

    def _reset_l2cap_state(self):
        """Forget the L2CAP channel and any partially reassembled packet."""
        self.l2cap_channel = None
        self.l2cap_packet = b''
        self.l2cap_packet_size = 0

    async def start(self):
        await self.device.start_advertising()

    # Called by asyncio when the UDP socket is created
    def connection_made(self, transport):
        self.transport = transport

    # Called by asyncio when a send or receive operation on the UDP socket fails
    def error_received(self, error):
        print(color(f'!!! UDP error: {error}', 'red'))

    # Called by asyncio when the UDP socket is closed
    def connection_lost(self, error):
        self.transport = None

    # Called by asyncio when a UDP datagram is received
    def datagram_received(self, data, _address):
        print(color(f'<<< [UDP]: {len(data)} bytes', 'green'))

        if self.l2cap_channel:
            # The 1-byte header stores (size - 1), so only 1..256 byte
            # packets can be framed (bytes([...]) would raise otherwise).
            if not 1 <= len(data) <= 256:
                print(
                    color(
                        f'!!! Cannot frame a {len(data)}-byte datagram for '
                        'L2CAP, dropping it',
                        'red',
                    )
                )
                return
            print(color('>>> [L2CAP]', 'yellow'))
            self.l2cap_channel.write(bytes([len(data) - 1]) + data)
        elif self.tx_subscriber:
            print(color('>>> [GATT TX]', 'yellow'))
            self.tx_characteristic.value = data
            self._run_in_background(
                self.device.notify_subscribers(self.tx_characteristic)
            )

    # Called when a write to the RX characteristic has been received
    def on_rx_write(self, _connection, data):
        print(color(f'<<< [GATT RX]: {len(data)} bytes', 'cyan'))
        if self.tx_socket:
            print(color('>>> [UDP]', 'magenta'))
            self.tx_socket.sendto(data)

    # Called when the subscription to the TX characteristic has changed
    def on_tx_subscription(self, peer, enabled):
        print(
            f'### [GATT TX] subscription from {peer}: '
            f'{"enabled" if enabled else "disabled"}'
        )
        self.tx_subscriber = peer if enabled else None

    # Called when an L2CAP packet is received
    def on_l2cap_packet(self, packet):
        print(color(f'<<< [L2CAP PACKET]: {len(packet)} bytes', 'cyan'))
        if self.tx_socket:
            print(color('>>> [UDP]', 'magenta'))
            self.tx_socket.sendto(packet)

    # Called when a new connection is established
    def on_coc(self, channel):
        print('*** CoC Connection', channel)
        # Start reassembly from a clean state for the new channel
        self._reset_l2cap_state()
        self.l2cap_channel = channel
        channel.sink = self.on_coc_sdu
        channel.on('close', lambda *_: self._on_l2cap_channel_closed(channel))

    def _on_l2cap_channel_closed(self, channel):
        # Only forget the channel if it is still the current one
        if self.l2cap_channel is channel:
            print(color('*** L2CAP channel closed', 'yellow'))
            self._reset_l2cap_state()
315
316
317# -----------------------------------------------------------------------------
async def run(
    hci_transport,
    device_address,
    role_or_peer_address,
    send_host,
    send_port,
    receive_host,
    receive_port,
):
    """Open the HCI transport, set up the bridge and its UDP sockets, and run.

    Args:
        hci_transport: transport specification passed to open_transport_or_link.
        device_address: address given to the local device.
        role_or_peer_address: 'node' to run a GattlinkNodeBridge; any other
            value is used as the peer address of a GattlinkHubBridge.
        send_host, send_port: UDP destination for data coming from the peer.
        receive_host, receive_port: local UDP address on which datagrams to
            forward to the peer are received.

    Returns when the HCI source terminates.
    """
    print('<<< connecting to HCI...')
    async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
        print('<<< connected')

        # Create the local device on top of the HCI transport
        device = Device.with_hci('Bumble GG', device_address, hci_source, hci_sink)

        # Instantiate a bridge object for the requested role
        if role_or_peer_address == 'node':
            bridge = GattlinkNodeBridge(device)
        else:
            bridge = GattlinkHubBridge(device, role_or_peer_address)

        # Create a UDP to RX bridge (receive from UDP, send to RX); the bridge
        # itself is the datagram protocol of this socket
        loop = asyncio.get_running_loop()
        await loop.create_datagram_endpoint(
            lambda: bridge, local_addr=(receive_host, receive_port)
        )

        # Create a UDP to TX bridge (receive from TX, send to UDP)
        bridge.tx_socket, _ = await loop.create_datagram_endpoint(
            asyncio.DatagramProtocol,
            remote_addr=(send_host, send_port),
        )

        await device.power_on()
        await bridge.start()

        # Wait until the source terminates
        await hci_source.wait_for_termination()
357
358
@click.command()
@click.argument('hci_transport')
@click.argument('device_address')
@click.argument('role_or_peer_address')
@click.option(
    '-sh', '--send-host', type=str, default='127.0.0.1', help='UDP host to send to'
)
@click.option('-sp', '--send-port', type=int, default=9001, help='UDP port to send to')
@click.option(
    '-rh',
    '--receive-host',
    type=str,
    default='127.0.0.1',
    help='UDP host to receive on',
)
@click.option(
    '-rp', '--receive-port', type=int, default=9000, help='UDP port to receive on'
)
def main(
    hci_transport,
    device_address,
    role_or_peer_address,
    send_host,
    send_port,
    receive_host,
    receive_port,
):
    """Bridge UDP datagrams with a Gattlink peer over Bluetooth LE.

    HCI_TRANSPORT is the HCI transport to open.

    DEVICE_ADDRESS is the address of the local device.

    ROLE_OR_PEER_ADDRESS is 'node' to act as the Gattlink node (GATT server
    and L2CAP CoC listener), or the address of the node to connect to as hub.
    """
    asyncio.run(
        run(
            hci_transport,
            device_address,
            role_or_peer_address,
            send_host,
            send_port,
            receive_host,
            receive_port,
        )
    )
397
398
399# -----------------------------------------------------------------------------
# Log level comes from the BUMBLE_LOGLEVEL environment variable (default: WARNING)
_log_level_name = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING')
logging.basicConfig(level=_log_level_name.upper())

if __name__ == '__main__':
    main()
403