1# Copyright 2024 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import logging
20import os
21import time
22from typing import Optional
23
24import click
25
26from bumble.colors import color
27from bumble.device import Device, DeviceConfiguration, Connection
28from bumble import core
29from bumble import hci
30from bumble import rfcomm
31from bumble import transport
32from bumble import utils
33
34
35# -----------------------------------------------------------------------------
36# Constants
37# -----------------------------------------------------------------------------
# Default service UUID for the bridged RFCOMM channel (default of --uuid).
DEFAULT_RFCOMM_UUID: str = "E6D55659-C8B4-4B85-96BB-B1143AF6D3AE"
# NOTE(review): not referenced in the visible part of this file.
DEFAULT_MTU: int = 4096
# Default local TCP port the client bridge listens on (client --tcp-port).
DEFAULT_CLIENT_TCP_PORT: int = 9544
# Default TCP port the server bridge connects to (server --tcp-port).
DEFAULT_SERVER_TCP_PORT: int = 9545

# Maximum number of bytes hex-dumped per traced buffer.
TRACE_MAX_SIZE: int = 48
44
45
46# -----------------------------------------------------------------------------
class Tracer:
    """
    Trace data buffers transmitted from one endpoint to another, with stats.

    Each call to trace_data() prints one line containing the buffer size, the
    time elapsed since the previous call on the same tracer, the instantaneous
    throughput over that interval, and a hex dump of the first bytes of the
    buffer (truncated with "..." when longer than the configured maximum).
    """

    def __init__(self, channel_name: str, max_size: Optional[int] = None) -> None:
        """
        Args:
            channel_name: label printed, in brackets, at the start of every line.
            max_size: maximum number of bytes to hex-dump per buffer. Defaults to
                the module-level TRACE_MAX_SIZE.
        """
        self.channel_name = channel_name
        # Only evaluate the module constant when no explicit size was given.
        self.max_size = TRACE_MAX_SIZE if max_size is None else max_size
        # Monotonic timestamp of the previous trace_data() call, or None before
        # the first call. A monotonic clock is used so that wall-clock
        # adjustments cannot produce negative or bogus intervals.
        self.last_ts: Optional[float] = None

    def trace_data(self, data: bytes) -> None:
        """Print one trace line for `data` and remember when it was seen."""
        now = time.monotonic()
        elapsed_s = now - self.last_ts if self.last_ts is not None else 0.0
        elapsed_ms = int(elapsed_s * 1000)
        # No meaningful rate for the first buffer (or a zero-length interval).
        instant_throughput_kbps = ((len(data) / elapsed_s) / 1000) if elapsed_s else 0.0

        hex_str = data[: self.max_size].hex() + (
            "..." if len(data) > self.max_size else ""
        )
        print(
            f"[{self.channel_name}] {len(data):4} bytes "
            f"(+{elapsed_ms:4}ms, {instant_throughput_kbps: 7.2f}kB/s) "
            f" {hex_str}"
        )

        self.last_ts = now
72
73
74# -----------------------------------------------------------------------------
class ServerBridge:
    """
    RFCOMM server bridge: waits for a peer to connect an RFCOMM channel.
    The RFCOMM channel may be associated with a UUID published in an SDP service
    description, or simply be on a system-assigned channel number.
    When the connection is made, the bridge connects a TCP socket to a remote host and
    bridges the data in both directions, with flow control.
    When the RFCOMM channel is closed, the bridge disconnects the TCP socket
    and waits for a new channel to be connected. Conversely, when the TCP side
    ends (end of stream or error), the RFCOMM channel is disconnected.
    """

    # Maximum number of bytes read from the TCP socket per iteration.
    READ_CHUNK_SIZE = 4096

    def __init__(
        self, channel: int, uuid: str, trace: bool, tcp_host: str, tcp_port: int
    ) -> None:
        """
        Args:
            channel: RFCOMM channel number passed to rfcomm.Server.listen(); it is
                replaced by listen()'s return value in start().
            uuid: service UUID published in the SDP service record.
            trace: when True, print a trace line for every bridged buffer.
            tcp_host: host of the TCP server to connect to for each RFCOMM channel.
            tcp_port: port of that TCP server.
        """
        self.device: Optional[Device] = None
        self.channel = channel
        self.uuid = uuid
        self.tcp_host = tcp_host
        self.tcp_port = tcp_port
        self.rfcomm_channel: Optional[rfcomm.DLC] = None
        # tcp_tracer traces data read from TCP (forwarded to RFCOMM);
        # rfcomm_tracer traces data received on RFCOMM (forwarded to TCP).
        self.tcp_tracer: Optional[Tracer]
        self.rfcomm_tracer: Optional[Tracer]

        if trace:
            self.tcp_tracer = Tracer(color("TCP->RFCOMM", "magenta"))
            self.rfcomm_tracer = Tracer(color("RFCOMM->TCP", "cyan"))
        else:
            self.rfcomm_tracer = None
            self.tcp_tracer = None

    async def start(self, device: Device) -> None:
        """
        Start listening for RFCOMM channels on `device`.

        Registers an RFCOMM server, publishes an SDP record for the listening
        channel and self.uuid, and makes the device connectable/discoverable.
        """
        self.device = device

        # Create and register a server
        rfcomm_server = rfcomm.Server(self.device)

        # Listen for incoming DLC connections
        self.channel = rfcomm_server.listen(self.on_rfcomm_channel, self.channel)

        # Setup the SDP to advertise this channel
        service_record_handle = 0x00010001
        self.device.sdp_service_records = {
            service_record_handle: rfcomm.make_service_sdp_records(
                service_record_handle, self.channel, core.UUID(self.uuid)
            )
        }

        # We're ready for a connection
        self.device.on("connection", self.on_connection)
        await self.set_available(True)

        print(
            color(
                (
                    f"### Listening for RFCOMM connection on {device.public_address}, "
                    f"channel {self.channel}"
                ),
                "yellow",
            )
        )

    async def set_available(self, available: bool) -> None:
        """Make the device connectable and discoverable (or neither)."""
        assert self.device
        await self.device.set_connectable(available)
        await self.device.set_discoverable(available)

    def on_connection(self, connection) -> None:
        """Handle a new Bluetooth connection: stop advertising availability."""
        print(color(f"@@@ Bluetooth connection: {connection}", "blue"))
        connection.on("disconnection", self.on_disconnection)

        # Don't accept new connections until we're disconnected
        utils.AsyncRunner.spawn(self.set_available(False))

    def on_disconnection(self, reason: int) -> None:
        """Handle a Bluetooth disconnection: become available again."""
        print(
            color("@@@ Bluetooth disconnection:", "red"),
            hci.HCI_Constant.error_name(reason),
        )

        # We're ready for a new connection
        utils.AsyncRunner.spawn(self.set_available(True))

    # Called when an RFCOMM channel is established
    @utils.AsyncRunner.run_in_task()
    async def on_rfcomm_channel(self, rfcomm_channel):
        """
        Bridge a newly opened RFCOMM channel to a fresh TCP connection.

        If the TCP connection cannot be opened, the RFCOMM channel is
        disconnected. Otherwise data is piped both ways until either side ends;
        on exit both the TCP socket and (if still connected) the RFCOMM channel
        are closed.
        """
        print(color("*** RFCOMM channel:", "cyan"), rfcomm_channel)

        # Connect to the TCP server
        print(
            color(
                f"### Connecting to TCP {self.tcp_host}:{self.tcp_port}",
                "yellow",
            )
        )
        try:
            reader, writer = await asyncio.open_connection(self.tcp_host, self.tcp_port)
        except OSError:
            print(color("!!! Connection failed", "red"))
            await rfcomm_channel.disconnect()
            return

        # Pipe data from RFCOMM to TCP
        def on_rfcomm_channel_closed():
            print(color("*** RFCOMM channel closed", "cyan"))
            # Closing the socket also makes reader.read() below return EOF,
            # which ends the TCP->RFCOMM loop.
            writer.close()

        def write_rfcomm_data(data):
            if self.rfcomm_tracer:
                self.rfcomm_tracer.trace_data(data)

            writer.write(data)

        rfcomm_channel.sink = write_rfcomm_data
        rfcomm_channel.on("close", on_rfcomm_channel_closed)

        # Pipe data from TCP to RFCOMM
        try:
            while True:
                data = await reader.read(self.READ_CHUNK_SIZE)

                if len(data) == 0:
                    print(color("### TCP end of stream", "yellow"))
                    break

                if self.tcp_tracer:
                    self.tcp_tracer.trace_data(data)

                rfcomm_channel.write(data)
                await rfcomm_channel.drain()
        except Exception as error:
            print(f"!!! Exception: {error}")
        finally:
            # Whatever ended the TCP side, don't leave the RFCOMM channel open
            # with a sink that writes into a dead socket.
            if rfcomm_channel.state == rfcomm.DLC.State.CONNECTED:
                try:
                    await rfcomm_channel.disconnect()
                except Exception as error:
                    print(color(f"!!! RFCOMM disconnect failed: {error}", "red"))
            writer.close()
            try:
                await writer.wait_closed()
            except OSError:
                # The peer may already have reset the connection; nothing
                # more to clean up.
                pass

        print(color("~~~ Bye bye", "magenta"))
216
217
218# -----------------------------------------------------------------------------
class ClientBridge:
    """
    RFCOMM client bridge: connects to a BR/EDR device, then waits for an inbound
    TCP connection on a specified port number. When a TCP client connects, an
    RFCOMM connection to the device is established, and the data is bridged in both
    directions, with flow control.
    When the TCP connection is closed by the client, the RFCOMM channel is
    disconnected, but the connection to the device remains, ready for a new TCP client
    to connect. When the RFCOMM channel is closed by the device, the TCP
    connection is closed.
    """

    # Maximum number of bytes read from the TCP socket per iteration.
    READ_CHUNK_SIZE = 4096

    def __init__(
        self,
        channel: int,
        uuid: str,
        trace: bool,
        address: str,
        tcp_host: str,
        tcp_port: int,
        encrypt: bool,
    ):
        """
        Args:
            channel: RFCOMM channel number on the remote device; 0 means look it
                up via SDP using `uuid`.
            uuid: service UUID used for the SDP lookup when `channel` is 0.
            trace: when True, print a trace line for every bridged buffer.
            address: Bluetooth address of the device to connect to.
            tcp_host: local address to listen on; "_" means all interfaces
                (host=None for asyncio.start_server).
            tcp_port: local TCP port to listen on.
            encrypt: when True, encrypt the Bluetooth link after connecting.
        """
        self.channel = channel
        self.uuid = uuid
        self.trace = trace
        self.address = address
        self.tcp_host = tcp_host
        self.tcp_port = tcp_port
        self.encrypt = encrypt
        self.device: Optional[Device] = None
        self.connection: Optional[Connection] = None
        # Set by connect() once the RFCOMM session is up; reset on failure or
        # disconnection.
        self.rfcomm_client: Optional[rfcomm.Client] = None
        self.rfcomm_mux: Optional[rfcomm.Multiplexer] = None
        # Only one TCP client is bridged at a time.
        self.tcp_connected: bool = False
        self.tcp_server: Optional[asyncio.AbstractServer] = None

        # tcp_tracer traces data read from TCP (forwarded to RFCOMM);
        # rfcomm_tracer traces data received on RFCOMM (forwarded to TCP).
        self.tcp_tracer: Optional[Tracer]
        self.rfcomm_tracer: Optional[Tracer]

        if trace:
            self.tcp_tracer = Tracer(color("TCP->RFCOMM", "magenta"))
            self.rfcomm_tracer = Tracer(color("RFCOMM->TCP", "cyan"))
        else:
            self.rfcomm_tracer = None
            self.tcp_tracer = None

    async def connect(self) -> None:
        """
        Ensure a BR/EDR connection to `self.address` with a running RFCOMM
        multiplexer (self.rfcomm_mux).

        Does nothing if a connection already exists. If a step after the
        baseband connection (encryption or RFCOMM start-up) fails, the
        half-initialized link is dropped and the exception re-raised, so that the
        next call starts from scratch instead of reusing a link that has no
        multiplexer (or no encryption when requested).
        """
        if self.connection:
            return

        print(color(f"@@@ Connecting to Bluetooth {self.address}", "blue"))
        assert self.device
        connection = await self.device.connect(
            self.address, transport=core.BT_BR_EDR_TRANSPORT
        )
        self.connection = connection
        print(color(f"@@@ Bluetooth connection: {connection}", "blue"))
        connection.on("disconnection", self.on_disconnection)

        try:
            if self.encrypt:
                print(color("@@@ Encrypting Bluetooth connection", "blue"))
                await connection.encrypt()
                print(color("@@@ Bluetooth connection encrypted", "blue"))

            self.rfcomm_client = rfcomm.Client(connection)
            try:
                self.rfcomm_mux = await self.rfcomm_client.start()
            except BaseException as e:
                print(color("!!! Failed to setup RFCOMM connection", "red"), e)
                raise
        except BaseException:
            # Forget the half-initialized link so the next attempt reconnects.
            self.connection = None
            self.rfcomm_client = None
            self.rfcomm_mux = None
            try:
                await connection.disconnect()
            except Exception as error:
                print(
                    color("!!! Failed to disconnect Bluetooth connection:", "red"),
                    error,
                )
            raise

    async def start(self, device: Device) -> None:
        """Make `device` non-connectable/non-discoverable and start the TCP listener."""
        self.device = device
        await device.set_connectable(False)
        await device.set_discoverable(False)

        # Called when a TCP connection is established
        async def on_tcp_connection(reader, writer):
            print(color("<<< TCP connection", "magenta"))
            if self.tcp_connected:
                print(
                    color("!!! TCP connection already active, rejecting new one", "red")
                )
                writer.close()
                return
            self.tcp_connected = True

            try:
                await self.pipe(reader, writer)
            except Exception as error:
                # Exception (not BaseException) so task cancellation propagates.
                print(color("!!! Exception while piping data:", "red"), error)
            finally:
                # Reset the flag first: if closing fails below, new TCP clients
                # must still be accepted.
                self.tcp_connected = False
                writer.close()
                try:
                    await writer.wait_closed()
                except OSError:
                    # The peer may already have reset the connection.
                    pass

        self.tcp_server = await asyncio.start_server(
            on_tcp_connection,
            host=self.tcp_host if self.tcp_host != "_" else None,
            port=self.tcp_port,
        )
        print(
            color(
                f"### Listening for TCP connections on port {self.tcp_port}", "magenta"
            )
        )

    async def pipe(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """
        Bridge one TCP client to a new RFCOMM channel until either side ends.

        Connects to the device if needed, resolves the RFCOMM channel (via SDP
        when self.channel is 0), opens a DLC and pipes data both ways. On exit
        the DLC is disconnected if it is still connected.
        """
        await self.connect()

        # Resolve the channel number from the UUID if needed
        if self.channel == 0:
            assert self.connection
            channel = await rfcomm.find_rfcomm_channel_with_uuid(
                self.connection, self.uuid
            )
            if channel:
                print(color(f"### Found RFCOMM channel {channel}", "yellow"))
            else:
                print(
                    color(f"!!! RFCOMM channel with UUID {self.uuid} not found", "red")
                )
                return
        else:
            channel = self.channel

        # Connect a new RFCOMM channel
        assert self.rfcomm_mux
        print(color(f"*** Opening RFCOMM channel {channel}", "green"))
        try:
            rfcomm_channel = await self.rfcomm_mux.open_dlc(channel)
            print(color(f"*** RFCOMM channel open: {rfcomm_channel}", "green"))
        except Exception as error:
            print(color(f"!!! RFCOMM open failed: {error}", "red"))
            return

        # Pipe data from RFCOMM to TCP
        def on_rfcomm_channel_closed():
            print(color("*** RFCOMM channel closed", "green"))
            # Close the TCP side too; this makes reader.read() below return EOF
            # instead of leaving the TCP client hanging.
            writer.close()

        def write_rfcomm_data(data):
            if self.rfcomm_tracer:
                self.rfcomm_tracer.trace_data(data)

            writer.write(data)

        rfcomm_channel.on("close", on_rfcomm_channel_closed)
        rfcomm_channel.sink = write_rfcomm_data

        # Pipe data from TCP to RFCOMM
        try:
            while True:
                data = await reader.read(self.READ_CHUNK_SIZE)

                if len(data) == 0:
                    print(color("### TCP end of stream", "yellow"))
                    break

                if self.tcp_tracer:
                    self.tcp_tracer.trace_data(data)

                rfcomm_channel.write(data)
                await rfcomm_channel.drain()
        except Exception as error:
            print(f"!!! Exception: {error}")
        finally:
            # Don't leak the DLC when the TCP side ends or fails.
            if rfcomm_channel.state == rfcomm.DLC.State.CONNECTED:
                try:
                    await rfcomm_channel.disconnect()
                except Exception as error:
                    print(color(f"!!! RFCOMM disconnect failed: {error}", "red"))

        print(color("~~~ Bye bye", "magenta"))

    def on_disconnection(self, reason: int) -> None:
        """Forget the Bluetooth connection and its RFCOMM session."""
        print(
            color("@@@ Bluetooth disconnection:", "red"),
            hci.HCI_Constant.error_name(reason),
        )
        self.connection = None
        self.rfcomm_client = None
        self.rfcomm_mux = None
397
398
399# -----------------------------------------------------------------------------
async def run(device_config, hci_transport, bridge):
    """
    Open `hci_transport`, create and power on a Classic-enabled device, start
    `bridge` on it, and keep running until the HCI transport terminates.

    The device is built from `device_config` (a config file path) when given,
    otherwise from a default DeviceConfiguration. Errors raised while starting
    or running the bridge are printed rather than propagated.
    """
    print("<<< connecting to HCI...")
    async with await transport.open_transport_or_link(hci_transport) as (
        hci_source,
        hci_sink,
    ):
        print("<<< connected")

        device = (
            Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
            if device_config
            else Device.from_config_with_hci(
                DeviceConfiguration(), hci_source, hci_sink
            )
        )
        device.classic_enabled = True

        await device.power_on()
        try:
            await bridge.start(device)
            # Block here for as long as the HCI transport stays up.
            await hci_source.wait_for_termination()
        except core.ConnectionError as error:
            print(color(f"!!! Bluetooth connection failed: {error}", "red"))
        except Exception as error:
            print(f"Exception while running bridge: {error}")
429
430
431# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
@click.option(
    "--device-config",
    metavar="CONFIG_FILE",
    help="Device configuration file",
)
@click.option(
    "--hci-transport", metavar="TRANSPORT_NAME", help="HCI transport", required=True
)
@click.option("--trace", is_flag=True, help="Trace bridged data to stdout")
@click.option(
    "--channel",
    metavar="CHANNEL_NUMBER",
    help="RFCOMM channel number",
    type=int,
    default=0,
)
@click.option(
    "--uuid",
    metavar="UUID",
    help="UUID for the RFCOMM channel",
    default=DEFAULT_RFCOMM_UUID,
)
def cli(
    context,
    device_config,
    hci_transport,
    trace,
    channel,
    uuid,
):
    # Top-level command group: stash the shared options in context.obj so the
    # `server` and `client` subcommands can build their bridge from them.
    # (No docstring on purpose: click would turn it into --help text.)
    context.ensure_object(dict)
    context.obj["device_config"] = device_config
    context.obj["hci_transport"] = hci_transport
    context.obj["trace"] = trace
    context.obj["channel"] = channel
    context.obj["uuid"] = uuid
470
471
472# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.option("--tcp-host", help="TCP host", default="localhost")
@click.option("--tcp-port", help="TCP port", default=DEFAULT_SERVER_TCP_PORT)
def server(context, tcp_host, tcp_port):
    # Server mode: accept RFCOMM channels and bridge each one to a TCP
    # connection made to tcp_host:tcp_port.
    options = context.obj
    bridge = ServerBridge(
        channel=options["channel"],
        uuid=options["uuid"],
        trace=options["trace"],
        tcp_host=tcp_host,
        tcp_port=tcp_port,
    )
    asyncio.run(run(options["device_config"], options["hci_transport"], bridge))
486
487
488# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.argument("bluetooth-address")
@click.option("--tcp-host", help="TCP host", default="_")
@click.option("--tcp-port", help="TCP port", default=DEFAULT_CLIENT_TCP_PORT)
@click.option("--encrypt", is_flag=True, help="Encrypt the connection")
def client(context, bluetooth_address, tcp_host, tcp_port, encrypt):
    # Client mode: listen for TCP clients locally and bridge each one to an
    # RFCOMM channel on the device at bluetooth_address.
    options = context.obj
    bridge = ClientBridge(
        channel=options["channel"],
        uuid=options["uuid"],
        trace=options["trace"],
        address=bluetooth_address,
        tcp_host=tcp_host,
        tcp_port=tcp_port,
        encrypt=encrypt,
    )
    asyncio.run(run(options["device_config"], options["hci_transport"], bridge))
506
507
508# -----------------------------------------------------------------------------
# Log level comes from the BUMBLE_LOGLEVEL environment variable (default WARNING).
logging.basicConfig(
    level=os.environ.get("BUMBLE_LOGLEVEL", "WARNING").upper()
)

if __name__ == "__main__":
    # click fills in the parameters from the command line.
    cli(obj={})  # pylint: disable=no-value-for-parameter
512