1# Copyright 2021-2022 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15# ----------------------------------------------------------------------------- 16# Imports 17# ----------------------------------------------------------------------------- 18from __future__ import annotations 19import contextlib 20import struct 21import asyncio 22import logging 23import io 24from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict 25 26from bumble import core 27from bumble import hci 28from bumble.colors import color 29from bumble.snoop import Snooper 30 31 32# ----------------------------------------------------------------------------- 33# Logging 34# ----------------------------------------------------------------------------- 35logger = logging.getLogger(__name__) 36 37# ----------------------------------------------------------------------------- 38# Information needed to parse HCI packets with a generic parser: 39# For each packet type, the info represents: 40# (length-size, length-offset, unpack-type) 41HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = { 42 hci.HCI_COMMAND_PACKET: (1, 2, 'B'), 43 hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'), 44 hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'), 45 hci.HCI_EVENT_PACKET: (1, 1, 'B'), 46 hci.HCI_ISO_DATA_PACKET: (2, 2, 'H'), 47} 48 49 50# ----------------------------------------------------------------------------- 51# Errors 52# 
# -----------------------------------------------------------------------------
# Errors
# -----------------------------------------------------------------------------
class TransportLostError(core.BaseBumbleError, RuntimeError):
    """The Transport has been lost/disconnected."""


class TransportInitError(core.BaseBumbleError, RuntimeError):
    """Error raised when the transport cannot be initialized."""


class TransportSpecError(core.BaseBumbleError, ValueError):
    """Error raised when the transport spec is invalid."""


# -----------------------------------------------------------------------------
# Typing Protocols
# -----------------------------------------------------------------------------
class TransportSink(Protocol):
    """Structural type for any object that accepts complete HCI packets."""

    def on_packet(self, packet: bytes) -> None: ...


class TransportSource(Protocol):
    """Structural type for any object that delivers HCI packets to a sink."""

    # Future that is completed when the source has terminated.
    terminated: asyncio.Future[None]

    def set_packet_sink(self, sink: TransportSink) -> None: ...


# -----------------------------------------------------------------------------
class PacketPump:
    """
    Pump HCI packets from a reader to a sink.
    """

    def __init__(self, reader: AsyncPacketReader, sink: TransportSink) -> None:
        self.reader = reader
        self.sink = sink

    async def run(self) -> None:
        """
        Read packets from the reader and deliver them to the sink.

        Invalid packets and exceptions raised by the sink are logged and skipped.
        The pump returns when the stream ends or when reading fails for any other
        reason: an `asyncio.StreamReader` raises again on every later read once
        it has hit EOF or stored an exception, so retrying would only spin.
        """
        while True:
            try:
                packet = await self.reader.next_packet()
            except asyncio.IncompleteReadError:
                # End of stream: nothing more can ever be read.
                logger.debug('end of stream, packet pump stopping')
                return
            except core.InvalidPacketError as error:
                logger.warning(f'!!! {error}')
                continue
            except Exception as error:
                logger.warning(f'!!! {error}')
                return

            try:
                # Deliver the packet to the sink
                self.sink.on_packet(packet)
            except Exception as error:
                logger.warning(f'!!! {error}')


# -----------------------------------------------------------------------------
class PacketParser:
    """
    In-line parser that accepts data and emits 'on_packet' when a full packet has been
    parsed.
    """

    # pylint: disable=attribute-defined-outside-init

    # Parser states
    NEED_TYPE = 0
    NEED_LENGTH = 1
    NEED_BODY = 2

    sink: Optional[TransportSink]
    # Packet types not found in HCI_PACKET_INFO, using the same
    # (length-size, length-offset, unpack-type) format.
    extended_packet_info: Dict[int, Tuple[int, int, str]]
    packet_info: Optional[Tuple[int, int, str]] = None

    def __init__(self, sink: Optional[TransportSink] = None) -> None:
        self.sink = sink
        self.extended_packet_info = {}
        self.reset()

    def reset(self) -> None:
        """Discard any partially parsed packet and wait for a packet-type byte."""
        self.state = PacketParser.NEED_TYPE
        self.bytes_needed = 1
        self.packet = bytearray()
        self.packet_info = None

    def feed_data(self, data: bytes) -> None:
        """
        Consume a chunk of raw bytes, emitting every packet that it completes.

        Packets may be split across chunks at any byte boundary; incomplete
        packets are buffered until the remaining bytes arrive. Exceptions raised
        by the sink are logged and do not interrupt parsing.

        Raises:
            core.InvalidPacketError: if a packet-type byte is unknown. The parser
                is reset first, so the next byte fed to it is interpreted as a
                packet type; the rest of `data` is not processed.
        """
        data_offset = 0
        data_left = len(data)
        while data_left and self.bytes_needed:
            consumed = min(self.bytes_needed, data_left)
            self.packet.extend(data[data_offset : data_offset + consumed])
            data_offset += consumed
            data_left -= consumed
            self.bytes_needed -= consumed

            if self.bytes_needed == 0:
                if self.state == PacketParser.NEED_TYPE:
                    packet_type = self.packet[0]
                    self.packet_info = HCI_PACKET_INFO.get(
                        packet_type
                    ) or self.extended_packet_info.get(packet_type)
                    if self.packet_info is None:
                        # Reset before raising: otherwise bytes_needed stays at 0
                        # and every later call to feed_data() is silently ignored.
                        self.reset()
                        raise core.InvalidPacketError(
                            f'invalid packet type {packet_type}'
                        )
                    self.state = PacketParser.NEED_LENGTH
                    # Read the header up to and including the length field.
                    self.bytes_needed = self.packet_info[0] + self.packet_info[1]
                elif self.state == PacketParser.NEED_LENGTH:
                    assert self.packet_info is not None
                    # The length field follows the type byte and the
                    # length-offset header bytes.
                    body_length = struct.unpack_from(
                        self.packet_info[2], self.packet, 1 + self.packet_info[1]
                    )[0]
                    self.bytes_needed = body_length
                    self.state = PacketParser.NEED_BODY

            # Emit a packet if one is complete (this also covers empty bodies,
            # where NEED_BODY is entered with bytes_needed already at 0)
            if self.state == PacketParser.NEED_BODY and not self.bytes_needed:
                if self.sink:
                    try:
                        self.sink.on_packet(bytes(self.packet))
                    except Exception as error:
                        logger.exception(
                            color(f'!!! Exception in on_packet: {error}', 'red')
                        )
                self.reset()

    def set_packet_sink(self, sink: TransportSink) -> None:
        self.sink = sink


# -----------------------------------------------------------------------------
class PacketReader:
    """
    Reader that reads HCI packets from a sync source.
    """

    def __init__(self, source: io.BufferedReader) -> None:
        self.source = source
        self.at_end = False

    def next_packet(self) -> Optional[bytes]:
        """
        Read one complete packet (type byte, header and body).

        Returns:
            The packet bytes, or None (with `at_end` set) when no packet-type
            byte could be read.

        Raises:
            core.InvalidPacketError: if the packet type is unknown or the source
                ends in the middle of a packet.
        """
        # Get the packet type
        packet_type = self.source.read(1)
        if len(packet_type) != 1:
            self.at_end = True
            return None

        # Get the packet info based on its type
        packet_info = HCI_PACKET_INFO.get(packet_type[0])
        if packet_info is None:
            raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found')

        # Read the header (that includes the length)
        header_size = packet_info[0] + packet_info[1]
        header = self.source.read(header_size)
        if len(header) != header_size:
            raise core.InvalidPacketError('packet too short')

        # Read the body
        body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0]
        body = self.source.read(body_length)
        if len(body) != body_length:
            raise core.InvalidPacketError('packet too short')

        return packet_type + header + body


# -----------------------------------------------------------------------------
class AsyncPacketReader:
    """
    Reader that reads HCI packets from an async source.
    """

    def __init__(self, source: asyncio.StreamReader) -> None:
        self.source = source

    async def next_packet(self) -> bytes:
        """
        Read one complete packet (type byte, header and body).

        Raises:
            core.InvalidPacketError: if the packet type is unknown.
            asyncio.IncompleteReadError: if the stream ends before a complete
                packet has been read (raised by `readexactly`).
        """
        # Get the packet type
        packet_type = await self.source.readexactly(1)

        # Get the packet info based on its type
        packet_info = HCI_PACKET_INFO.get(packet_type[0])
        if packet_info is None:
            raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found')

        # Read the header (that includes the length)
        header_size = packet_info[0] + packet_info[1]
        header = await self.source.readexactly(header_size)

        # Read the body
        body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0]
        body = await self.source.readexactly(body_length)

        return packet_type + header + body


# -----------------------------------------------------------------------------
class AsyncPipeSink:
    """
    Sink that forwards packets asynchronously to another sink.

    Each packet is handed to the wrapped sink from a callback scheduled with
    `loop.call_soon`, not from within `on_packet` itself.
    """

    def __init__(self, sink: TransportSink) -> None:
        self.sink = sink
        self.loop = asyncio.get_running_loop()

    def on_packet(self, packet: bytes) -> None:
        self.loop.call_soon(self.sink.on_packet, packet)


# -----------------------------------------------------------------------------
class BaseSource:
    """
    Base class designed to be subclassed by transport-specific source classes
    """

    terminated: asyncio.Future[None]
    sink: Optional[TransportSink]

    def __init__(self) -> None:
        self.terminated = asyncio.get_running_loop().create_future()
        self.sink = None

    def set_packet_sink(self, sink: TransportSink) -> None:
        self.sink = sink

    def on_transport_lost(self) -> None:
        """
        Mark the source as terminated and notify the sink, if it has an
        `on_transport_lost` method.
        """
        if not self.terminated.done():
            self.terminated.set_result(None)

        if self.sink:
            if hasattr(self.sink, 'on_transport_lost'):
                self.sink.on_transport_lost()

    async def wait_for_termination(self) -> None:
        """
        Convenience method for backward compatibility. Prefer using the `terminated`
        attribute instead.
        """
        return await self.terminated

    def close(self) -> None:
        pass


# -----------------------------------------------------------------------------
class ParserSource(BaseSource):
    """
    Base class for sources that use an HCI parser.
    """

    parser: PacketParser

    def __init__(self) -> None:
        super().__init__()
        self.parser = PacketParser()

    def set_packet_sink(self, sink: TransportSink) -> None:
        # Keep the source and its parser pointing at the same sink.
        super().set_packet_sink(sink)
        self.parser.set_packet_sink(sink)


# -----------------------------------------------------------------------------
class StreamPacketSource(asyncio.Protocol, ParserSource):
    """asyncio protocol that feeds every received chunk of bytes to the parser."""

    def data_received(self, data: bytes) -> None:
        self.parser.feed_data(data)


# -----------------------------------------------------------------------------
class StreamPacketSink:
    """Sink that writes each packet to an asyncio write transport."""

    def __init__(self, transport: asyncio.WriteTransport) -> None:
        self.transport = transport

    def on_packet(self, packet: bytes) -> None:
        self.transport.write(packet)

    def close(self) -> None:
        self.transport.close()


# -----------------------------------------------------------------------------
class Transport:
    """
    Base class for all transports.

    A Transport represents a source and a sink together.
    An instance must be closed by calling close() when no longer used. Instances
    implement the asynchronous context manager protocol so that they may be used
    in an `async with` statement.
    An instance is iterable. The iterator yields, in order, its source and sink, so
    that it may be used with a convenient call syntax like:

    async with create_transport() as (source, sink):
        ...
    """

    def __init__(self, source: TransportSource, sink: TransportSink) -> None:
        self.source = source
        self.sink = sink

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        await self.close()

    def __iter__(self):
        return iter((self.source, self.sink))

    async def close(self) -> None:
        """Close the source, then the sink, when they have a `close` method."""
        try:
            if hasattr(self.source, 'close'):
                self.source.close()
        finally:
            # Still close the sink if closing the source failed.
            if hasattr(self.sink, 'close'):
                self.sink.close()


# -----------------------------------------------------------------------------
class PumpedPacketSource(ParserSource):
    """
    Source whose pump task repeatedly awaits `receive()` and feeds the returned
    bytes to the parser.
    """

    pump_task: Optional[asyncio.Task[None]]

    def __init__(self, receive) -> None:
        super().__init__()
        self.receive_function = receive
        self.pump_task = None

    def start(self) -> None:
        """Start the pump task; `terminated` is completed when it stops."""

        async def pump_packets() -> None:
            while True:
                try:
                    packet = await self.receive_function()
                    self.parser.feed_data(packet)
                except asyncio.CancelledError:
                    logger.debug('source pump task done')
                    # `terminated` may already be resolved (for example by
                    # on_transport_lost); resolving it twice would raise.
                    if not self.terminated.done():
                        self.terminated.set_result(None)
                    break
                except Exception as error:
                    logger.warning(f'exception while waiting for packet: {error}')
                    if not self.terminated.done():
                        self.terminated.set_exception(error)
                    break

        self.pump_task = asyncio.create_task(pump_packets())

    def close(self) -> None:
        if self.pump_task:
            self.pump_task.cancel()


# -----------------------------------------------------------------------------
class PumpedPacketSink:
    """
    Sink that queues packets and sends them in order, one at a time, from a pump
    task that awaits `send(packet)`.
    """

    pump_task: Optional[asyncio.Task[None]]

    def __init__(self, send) -> None:
        self.send_function = send
        self.packet_queue: asyncio.Queue[bytes] = asyncio.Queue()
        self.pump_task = None

    def on_packet(self, packet: bytes) -> None:
        self.packet_queue.put_nowait(packet)

    def start(self) -> None:
        """Start the pump task that drains the packet queue."""

        async def pump_packets() -> None:
            while True:
                try:
                    packet = await self.packet_queue.get()
                    await self.send_function(packet)
                except asyncio.CancelledError:
                    logger.debug('sink pump task done')
                    break
                except Exception as error:
                    # NOTE: the pump stops here; packets queued after this point
                    # are never sent.
                    logger.warning(f'exception while sending packet: {error}')
                    break

        self.pump_task = asyncio.create_task(pump_packets())

    def close(self) -> None:
        if self.pump_task:
            self.pump_task.cancel()


# -----------------------------------------------------------------------------
class PumpedTransport(Transport):
    """Transport made of a pumped source and a pumped sink."""

    source: PumpedPacketSource
    sink: PumpedPacketSink

    def __init__(
        self,
        source: PumpedPacketSource,
        sink: PumpedPacketSink,
    ) -> None:
        super().__init__(source, sink)

    def start(self) -> None:
        """Start both pump tasks."""
        self.source.start()
        self.sink.start()


# -----------------------------------------------------------------------------
class SnoopingTransport(Transport):
    """Transport wrapper that snoops on packets to/from a wrapped transport."""

    @staticmethod
    def create_with(
        transport: Transport, snooper: ContextManager[Snooper]
    ) -> SnoopingTransport:
        """
        Create an instance given a snooper that works as a context manager.

        The returned instance will exit the snooper context when it is closed.
        """
        with contextlib.ExitStack() as exit_stack:
            # pop_all() transfers the snooper's exit callback to the new instance;
            # if construction raises first, the ExitStack exits the snooper here.
            return SnoopingTransport(
                transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close
            )
        raise core.UnreachableError()  # Satisfy the type checker

    class Source:
        """Source wrapper that snoops controller-to-host packets."""

        sink: Optional[TransportSink]

        @property
        def metadata(self) -> dict[str, Any]:
            return getattr(self.source, 'metadata', {})

        def __init__(self, source: TransportSource, snooper: Snooper):
            self.source = source
            self.snooper = snooper
            self.terminated = source.terminated
            # Set by set_packet_sink(); initialized so on_packet() can test it.
            self.sink = None

        def set_packet_sink(self, sink: TransportSink) -> None:
            self.sink = sink
            # Interpose this wrapper between the wrapped source and the sink.
            self.source.set_packet_sink(self)

        def on_packet(self, packet: bytes) -> None:
            self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST)
            if self.sink:
                self.sink.on_packet(packet)

    class Sink:
        """Sink wrapper that snoops host-to-controller packets."""

        def __init__(self, sink: TransportSink, snooper: Snooper) -> None:
            self.sink = sink
            self.snooper = snooper

        def on_packet(self, packet: bytes) -> None:
            self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER)
            if self.sink:
                self.sink.on_packet(packet)

    def __init__(
        self,
        transport: Transport,
        snooper: Snooper,
        close_snooper=None,
    ) -> None:
        super().__init__(
            self.Source(transport.source, snooper), self.Sink(transport.sink, snooper)
        )
        self.transport = transport
        self.close_snooper = close_snooper

    async def close(self):
        """Close the wrapped transport, then release the snooper if needed."""
        try:
            await self.transport.close()
        finally:
            # Release the snooper even if closing the wrapped transport failed.
            if self.close_snooper:
                self.close_snooper()