1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19import contextlib
20import struct
21import asyncio
22import logging
23import io
24from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict
25
26from bumble import core
27from bumble import hci
28from bumble.colors import color
29from bumble.snoop import Snooper
30
31
32# -----------------------------------------------------------------------------
33# Logging
34# -----------------------------------------------------------------------------
# Module-level logger, named after this module, used by the classes below.
logger = logging.getLogger(__name__)
36
37# -----------------------------------------------------------------------------
# Information needed to parse HCI packets with a generic parser.
# The table is keyed by packet type (the first byte of a packet). For each packet
# type, the info is a tuple:
# (length-size, length-offset, unpack-type)
# where:
#   - length-size is the size, in bytes, of the length field
#   - length-offset is the offset of the length field, counted from the byte that
#     follows the packet type
#   - unpack-type is the `struct` format used to decode the length field
# The parsers and readers in this module read `length-size + length-offset` header
# bytes after the packet type, then as many body bytes as the decoded length says.
# NOTE(review): 'H' has no byte-order prefix, so `struct` decodes it in native byte
# order. HCI length fields are presumably little-endian -- confirm whether '<H' is
# needed to support big-endian hosts.
HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = {
    hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
    hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
    hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
    hci.HCI_EVENT_PACKET: (1, 1, 'B'),
    hci.HCI_ISO_DATA_PACKET: (2, 2, 'H'),
}
48
49
50# -----------------------------------------------------------------------------
51# Errors
52# -----------------------------------------------------------------------------
class TransportLostError(core.BaseBumbleError, RuntimeError):
    """
    The Transport has been lost/disconnected.

    Derives from both `core.BaseBumbleError` and `RuntimeError`, so a handler for
    either base class will catch it.
    """
55
56
class TransportInitError(core.BaseBumbleError, RuntimeError):
    """
    Error raised when the transport cannot be initialized.

    Derives from both `core.BaseBumbleError` and `RuntimeError`, so a handler for
    either base class will catch it.
    """
59
60
class TransportSpecError(core.BaseBumbleError, ValueError):
    """
    Error raised when the transport spec is invalid.

    Derives from both `core.BaseBumbleError` and `ValueError`, so a handler for
    either base class will catch it.
    """
63
64
65# -----------------------------------------------------------------------------
66# Typing Protocols
67# -----------------------------------------------------------------------------
class TransportSink(Protocol):
    """
    Structural type for objects that can receive packets.

    Any object with a compatible `on_packet` method satisfies this protocol.
    """

    # Called with the bytes of one packet.
    def on_packet(self, packet: bytes) -> None: ...
70
71
class TransportSource(Protocol):
    """
    Structural type for objects that produce packets and deliver them to a sink.
    """

    # Future that tracks the end of life of the source. The sources defined in this
    # module complete it when the transport is lost or when their pump task stops.
    terminated: asyncio.Future[None]

    # Register the sink to which the source delivers its packets.
    def set_packet_sink(self, sink: TransportSink) -> None: ...
76
77
78# -----------------------------------------------------------------------------
class PacketPump:
    """
    Pump HCI packets from a reader to a sink.

    The reader is any object with an awaitable `next_packet()` method that returns
    one packet at a time; each packet is passed to the sink's `on_packet()` method.
    """

    def __init__(self, reader: AsyncPacketReader, sink: TransportSink) -> None:
        self.reader = reader
        self.sink = sink

    async def run(self) -> None:
        """
        Read packets and deliver them to the sink until the end of the stream.

        Errors raised while reading or delivering a packet are logged, and the pump
        carries on with the next packet. The exception to that rule is
        `asyncio.IncompleteReadError`, which is what `StreamReader.readexactly()`
        raises when the stream has ended: in that case this method returns.
        """
        while True:
            try:
                # Deliver the packet to the sink
                self.sink.on_packet(await self.reader.next_packet())
            except asyncio.IncompleteReadError:
                # End of stream. Every subsequent read would fail the same way, so
                # continuing would just spin forever logging the same error.
                logger.debug('end of stream reached, packet pump done')
                return
            except Exception as error:
                logger.warning(f'!!! {error}')
95
96
97# -----------------------------------------------------------------------------
class PacketParser:
    """
    In-line parser that accepts data and emits 'on_packet' when a full packet has been
    parsed.

    Data may be fed in chunks of any size: a chunk may hold a fraction of a packet,
    or several packets. The parser accumulates bytes and goes through three states
    for each packet: it needs the packet type (1 byte), then the header up to and
    including the length field, then the body.
    """

    # pylint: disable=attribute-defined-outside-init

    # Parser states
    NEED_TYPE = 0
    NEED_LENGTH = 1
    NEED_BODY = 2

    sink: Optional[TransportSink]
    # Info for extra packet types, looked up when a type is not in HCI_PACKET_INFO.
    # Entries have the same (length-size, length-offset, unpack-type) layout.
    extended_packet_info: Dict[int, Tuple[int, int, str]]
    # Info for the packet being parsed (None until its type is known)
    packet_info: Optional[Tuple[int, int, str]] = None

    def __init__(self, sink: Optional[TransportSink] = None) -> None:
        self.sink = sink
        self.extended_packet_info = {}
        self.reset()

    def reset(self) -> None:
        """Drop any partially parsed packet and wait for a new packet type byte."""
        self.state = PacketParser.NEED_TYPE
        self.bytes_needed = 1
        self.packet = bytearray()
        self.packet_info = None

    def feed_data(self, data: bytes) -> None:
        """
        Consume `data`, calling the sink's `on_packet` for each packet completed.

        Exceptions raised by the sink are logged and do not interrupt the parsing.

        Raises:
            core.InvalidPacketError: if a packet type is found that is neither in
                HCI_PACKET_INFO nor in `extended_packet_info`. The parser is reset
                before raising, so it can be fed again, but the rest of `data` is
                not processed.
        """
        data_offset = 0
        data_left = len(data)
        while data_left and self.bytes_needed:
            consumed = min(self.bytes_needed, data_left)
            self.packet.extend(data[data_offset : data_offset + consumed])
            data_offset += consumed
            data_left -= consumed
            self.bytes_needed -= consumed

            if self.bytes_needed:
                # All the data has been consumed, and the current state still needs
                # more: wait for the next call.
                break

            if self.state == PacketParser.NEED_TYPE:
                packet_type = self.packet[0]
                self.packet_info = HCI_PACKET_INFO.get(
                    packet_type
                ) or self.extended_packet_info.get(packet_type)
                if self.packet_info is None:
                    # Start over before raising. Without this, `bytes_needed` would
                    # stay at 0 and every subsequent call would be a silent no-op.
                    self.reset()
                    raise core.InvalidPacketError(f'invalid packet type {packet_type}')
                self.state = PacketParser.NEED_LENGTH
                self.bytes_needed = self.packet_info[0] + self.packet_info[1]
            elif self.state == PacketParser.NEED_LENGTH:
                assert self.packet_info is not None
                # The offset is relative to the byte that follows the packet type,
                # hence the `1 +`.
                body_length = struct.unpack_from(
                    self.packet_info[2], self.packet, 1 + self.packet_info[1]
                )[0]
                self.bytes_needed = body_length
                self.state = PacketParser.NEED_BODY

            # Emit a packet if one is complete (this includes the case where the
            # length that was just decoded is 0).
            if self.state == PacketParser.NEED_BODY and not self.bytes_needed:
                if self.sink:
                    try:
                        self.sink.on_packet(bytes(self.packet))
                    except Exception as error:
                        logger.exception(
                            color(f'!!! Exception in on_packet: {error}', 'red')
                        )
                self.reset()

    def set_packet_sink(self, sink: TransportSink) -> None:
        """Set the sink to which complete packets are delivered."""
        self.sink = sink
168
169
170# -----------------------------------------------------------------------------
class PacketReader:
    """
    Reader that reads HCI packets, one at a time, from a sync source.

    The `at_end` attribute becomes True once the source has no more packets.
    """

    def __init__(self, source: io.BufferedReader) -> None:
        self.source = source
        self.at_end = False

    def _read_exactly(self, size: int) -> bytes:
        # Read `size` bytes from the source, failing if fewer are available.
        chunk = self.source.read(size)
        if len(chunk) != size:
            raise core.InvalidPacketError('packet too short')
        return chunk

    def next_packet(self) -> Optional[bytes]:
        """
        Read the next packet from the source.

        Returns:
            The complete packet (type byte, header and body), or None if the source
            has no more data, in which case `at_end` is set.

        Raises:
            core.InvalidPacketError: if the packet type is unknown or if the source
                ends in the middle of a packet.
        """
        # A packet starts with a 1-byte type. No byte at all means end of source.
        type_byte = self.source.read(1)
        if len(type_byte) != 1:
            self.at_end = True
            return None

        # The type tells where the length field is, and how to decode it
        info = HCI_PACKET_INFO.get(type_byte[0])
        if info is None:
            raise core.InvalidPacketError(f'invalid packet type {type_byte[0]} found')
        length_size, length_offset, length_format = info

        # The header runs up to and including the length field
        header = self._read_exactly(length_size + length_offset)

        # The body is as long as the length field says
        body_length = struct.unpack_from(length_format, header, length_offset)[0]
        body = self._read_exactly(body_length)

        return type_byte + header + body
205
206
207# -----------------------------------------------------------------------------
class AsyncPacketReader:
    """
    Reader that reads HCI packets, one at a time, from an async source.
    """

    def __init__(self, source: asyncio.StreamReader) -> None:
        self.source = source

    async def next_packet(self) -> bytes:
        """
        Read the next packet from the source.

        Returns:
            The complete packet (type byte, header and body).

        Raises:
            core.InvalidPacketError: if the packet type is unknown.

        Errors raised by the source's `readexactly()` are not caught here.
        """
        # A packet starts with a 1-byte type
        type_byte = await self.source.readexactly(1)

        # The type tells where the length field is, and how to decode it
        info = HCI_PACKET_INFO.get(type_byte[0])
        if info is None:
            raise core.InvalidPacketError(f'invalid packet type {type_byte[0]} found')
        length_size, length_offset, length_format = info

        # The header runs up to and including the length field
        header = await self.source.readexactly(length_size + length_offset)

        # The body is as long as the length field says
        body_length = struct.unpack_from(length_format, header, length_offset)[0]
        body = await self.source.readexactly(body_length)

        return type_byte + header + body
234
235
236# -----------------------------------------------------------------------------
class AsyncPipeSink:
    """
    Sink that relays packets to another sink asynchronously: each packet is handed
    to the event loop, which delivers it on a later iteration.
    """

    def __init__(self, sink: TransportSink) -> None:
        # The loop is captured here, so instances must be created while an event
        # loop is running.
        self.loop = asyncio.get_running_loop()
        self.sink = sink

    def on_packet(self, packet: bytes) -> None:
        # Schedule the delivery rather than calling the target sink directly
        deliver = self.sink.on_packet
        self.loop.call_soon(deliver, packet)
248
249
250# -----------------------------------------------------------------------------
class BaseSource:
    """
    Base class designed to be subclassed by transport-specific source classes.

    It keeps track of the packet sink, and of a `terminated` future that is
    completed when the transport is reported lost.
    """

    terminated: asyncio.Future[None]
    sink: Optional[TransportSink]

    def __init__(self) -> None:
        # The future is created on the running loop, so instances must be created
        # while an event loop is running.
        self.sink = None
        self.terminated = asyncio.get_running_loop().create_future()

    def set_packet_sink(self, sink: TransportSink) -> None:
        self.sink = sink

    def on_transport_lost(self) -> None:
        """
        Record that the transport is gone, and tell the sink about it if the sink
        has an `on_transport_lost` method.
        """
        termination = self.terminated
        if not termination.done():
            termination.set_result(None)

        # `on_transport_lost` is not part of the TransportSink protocol: only
        # notify sinks that have it.
        sink = self.sink
        if sink and hasattr(sink, 'on_transport_lost'):
            sink.on_transport_lost()

    async def wait_for_termination(self) -> None:
        """
        Convenience method for backward compatibility. Prefer using the `terminated`
        attribute instead.
        """
        return await self.terminated

    def close(self) -> None:
        # Nothing to release here
        pass
283
284
285# -----------------------------------------------------------------------------
class ParserSource(BaseSource):
    """
    Base class for sources that use an HCI parser.

    Each instance owns a `PacketParser`. The data fed to that parser (as the
    subclasses in this module do) is delivered, packet by packet, to the sink.
    """

    # Parser that turns the data it is fed into packets
    parser: PacketParser

    def __init__(self) -> None:
        super().__init__()
        # The parser is created without a sink: one is attached when
        # set_packet_sink() is called.
        self.parser = PacketParser()

    def set_packet_sink(self, sink: TransportSink) -> None:
        # Register the sink with the base class and with the parser, since it is
        # the parser that calls the sink when a packet is complete.
        super().set_packet_sink(sink)
        self.parser.set_packet_sink(sink)
300
301
302# -----------------------------------------------------------------------------
class StreamPacketSource(asyncio.Protocol, ParserSource):
    """
    Source that is also an `asyncio.Protocol`: the data passed to `data_received`
    is fed to the HCI parser.
    """

    def data_received(self, data: bytes) -> None:
        # The data may hold a fraction of a packet or several packets: the parser
        # accumulates it and emits the complete packets.
        self.parser.feed_data(data)
306
307
308# -----------------------------------------------------------------------------
class StreamPacketSink:
    """
    Sink that writes packets to an asyncio write transport.
    """

    def __init__(self, transport: asyncio.WriteTransport) -> None:
        self.transport = transport

    def on_packet(self, packet: bytes) -> None:
        # The packet is written as-is
        self.transport.write(packet)

    def close(self) -> None:
        # Closing the sink closes the underlying transport
        self.transport.close()
318
319
320# -----------------------------------------------------------------------------
class Transport:
    """
    Base class for all transports.

    A Transport represents a source and a sink together.
    An instance must be closed by calling close() when no longer used. Instances
    implement the asynchronous context manager protocol so that they may be used in
    an `async with` statement.
    An instance is iterable. The iterator yields, in order, its source and sink, so
    that it may be used with a convenient call syntax like:

    async with create_transport() as (source, sink):
        ...
    """

    def __init__(self, source: TransportSource, sink: TransportSink) -> None:
        self.source = source
        self.sink = sink

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        # Close unconditionally, whether or not the `async with` body raised.
        # Nothing is returned, so an exception from the body is not suppressed.
        await self.close()

    def __iter__(self):
        return iter((self.source, self.sink))

    async def close(self) -> None:
        """
        Close the source and the sink, for those that have a `close` method.

        The sink is closed even if closing the source raises; the exception is
        propagated to the caller after the sink has been closed.
        """
        try:
            if hasattr(self.source, 'close'):
                self.source.close()
        finally:
            if hasattr(self.sink, 'close'):
                self.sink.close()
354
355
356# -----------------------------------------------------------------------------
class PumpedPacketSource(ParserSource):
    """
    Source that gets its data by repeatedly awaiting a `receive` function in a
    background task, and feeds that data to the HCI parser.
    """

    # Task that runs the receive loop (None until start() is called)
    pump_task: Optional[asyncio.Task[None]]

    def __init__(self, receive) -> None:
        super().__init__()
        self.receive_function = receive
        self.pump_task = None

    def start(self) -> None:
        """
        Start the task that receives data and feeds it to the parser.

        The task stops when it is cancelled (which is what close() does), or when
        receiving or parsing raises. In the first case `terminated` is completed
        with None, in the second case with the exception that was raised.
        """

        async def pump_packets() -> None:
            while True:
                try:
                    packet = await self.receive_function()
                    self.parser.feed_data(packet)
                except asyncio.CancelledError:
                    logger.debug('source pump task done')
                    # The future may have been completed already (for instance by
                    # on_transport_lost()). Completing it twice would raise
                    # InvalidStateError from within this task.
                    if not self.terminated.done():
                        self.terminated.set_result(None)
                    break
                except Exception as error:
                    logger.warning(f'exception while waiting for packet: {error}')
                    if not self.terminated.done():
                        self.terminated.set_exception(error)
                    break

        self.pump_task = asyncio.create_task(pump_packets())

    def close(self) -> None:
        # Stop the receive loop, if it was started
        if self.pump_task:
            self.pump_task.cancel()
385
386
387# -----------------------------------------------------------------------------
class PumpedPacketSink:
    """
    Sink that queues packets and sends them, in order, by awaiting a `send`
    function in a background task.
    """

    def __init__(self, send):
        self.send_function = send
        self.packet_queue = asyncio.Queue()
        self.pump_task = None

    def on_packet(self, packet: bytes) -> None:
        # Only queue the packet here, it is sent by the pump task
        self.packet_queue.put_nowait(packet)

    def start(self):
        """
        Start the task that sends the queued packets.

        The task stops when it is cancelled (which is what close() does), or when
        sending a packet raises.
        """

        async def pump_packets():
            try:
                while True:
                    next_packet = await self.packet_queue.get()
                    await self.send_function(next_packet)
            except asyncio.CancelledError:
                logger.debug('sink pump task done')
            except Exception as error:
                logger.warning(f'exception while sending packet: {error}')

        self.pump_task = asyncio.create_task(pump_packets())

    def close(self):
        # Nothing to stop if the pump was never started
        if not self.pump_task:
            return
        self.pump_task.cancel()
415
416
417# -----------------------------------------------------------------------------
class PumpedTransport(Transport):
    """
    Transport made of a pumped source and a pumped sink.

    The pump tasks of the source and of the sink are only created when start() is
    called.
    """

    source: PumpedPacketSource
    sink: PumpedPacketSink

    def __init__(
        self,
        source: PumpedPacketSource,
        sink: PumpedPacketSink,
    ) -> None:
        super().__init__(source, sink)

    def start(self) -> None:
        # Start the pump of the source, then the pump of the sink
        self.source.start()
        self.sink.start()
432
433
434# -----------------------------------------------------------------------------
class SnoopingTransport(Transport):
    """
    Transport wrapper that snoops on packets to/from a wrapped transport.

    The source and the sink of the wrapped transport are each wrapped in an object
    that passes every packet to the snooper before forwarding it.
    """

    @staticmethod
    def create_with(
        transport: Transport, snooper: ContextManager[Snooper]
    ) -> SnoopingTransport:
        """
        Create an instance given a snooper that works as a context manager.

        The returned instance will exit the snooper context when it is closed.
        """
        with contextlib.ExitStack() as exit_stack:
            # enter_context() enters the snooper context. pop_all() then moves the
            # pending exit to a new stack, so that leaving this `with` block does
            # not exit the snooper context: that happens when the `close` method of
            # the new stack, passed as `close_snooper`, is called.
            return SnoopingTransport(
                transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close
            )
        raise core.UnreachableError()  # Satisfy the type checker

    class Source:
        """
        Source wrapper. It registers itself as the sink of the wrapped source, snoops
        the packets it receives, and forwards them to the sink set by the user.
        """

        sink: Optional[TransportSink]

        def __init__(self, source: TransportSource, snooper: Snooper):
            self.source = source
            self.snooper = snooper
            # Share the future of the wrapped source
            self.terminated = source.terminated
            # No sink until set_packet_sink() is called
            self.sink = None

        @property
        def metadata(self) -> dict[str, Any]:
            # Metadata of the wrapped source, or an empty dict if it has none
            return getattr(self.source, 'metadata', {})

        def set_packet_sink(self, sink: TransportSink) -> None:
            self.sink = sink
            # Interpose this object between the wrapped source and the sink
            self.source.set_packet_sink(self)

        def on_packet(self, packet: bytes) -> None:
            self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST)
            if self.sink:
                self.sink.on_packet(packet)

        def on_transport_lost(self) -> None:
            # The wrapped source sees this object as its sink. A source that
            # notifies its sink when the transport is lost (as `BaseSource` does)
            # would otherwise stop here and never reach the real sink.
            if self.sink and hasattr(self.sink, 'on_transport_lost'):
                self.sink.on_transport_lost()

    class Sink:
        """
        Sink wrapper. It snoops the packets it receives, and forwards them to the
        wrapped sink.
        """

        def __init__(self, sink: TransportSink, snooper: Snooper) -> None:
            self.sink = sink
            self.snooper = snooper

        def on_packet(self, packet: bytes) -> None:
            self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER)
            if self.sink:
                self.sink.on_packet(packet)

    def __init__(
        self,
        transport: Transport,
        snooper: Snooper,
        close_snooper=None,
    ) -> None:
        """
        Wrap `transport`, passing the packets in both directions to `snooper`.

        `close_snooper`, if set, is called with no argument when the instance is
        closed.
        """
        super().__init__(
            self.Source(transport.source, snooper), self.Sink(transport.sink, snooper)
        )
        self.transport = transport
        self.close_snooper = close_snooper

    async def close(self):
        """
        Close the wrapped transport, then the snooper (if `close_snooper` was set).

        The snooper is closed even if closing the wrapped transport raises.
        """
        try:
            await self.transport.close()
        finally:
            if self.close_snooper:
                self.close_snooper()
500