# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations

from collections.abc import Callable, MutableMapping
import datetime
from typing import cast, Any, Optional
import logging

from bumble import avc
from bumble import avctp
from bumble import avdtp
from bumble import avrcp
from bumble import crypto
from bumble import rfcomm
from bumble import sdp
from bumble.colors import color
from bumble.att import ATT_CID, ATT_PDU
from bumble.smp import SMP_CID, SMP_Command
from bumble.core import name_or_number
from bumble.l2cap import (
    L2CAP_PDU,
    L2CAP_CONNECTION_REQUEST,
    L2CAP_CONNECTION_RESPONSE,
    L2CAP_SIGNALING_CID,
    L2CAP_LE_SIGNALING_CID,
    L2CAP_Control_Frame,
    L2CAP_Connection_Request,
    L2CAP_Connection_Response,
)
from bumble.hci import (
    Address,
    HCI_EVENT_PACKET,
    HCI_ACL_DATA_PACKET,
    HCI_DISCONNECTION_COMPLETE_EVENT,
    HCI_AclDataPacketAssembler,
    HCI_Packet,
    HCI_Event,
    HCI_AclDataPacket,
    HCI_Disconnection_Complete_Event,
)


# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Display names for PSM values, used when a traced L2CAP channel's PSM has no
# dedicated decoder below.
PSM_NAMES = {
    rfcomm.RFCOMM_PSM: 'RFCOMM',
    sdp.SDP_PSM: 'SDP',
    avdtp.AVDTP_PSM: 'AVDTP',
    avctp.AVCTP_PSM: 'AVCTP',
    # TODO: add more PSM values
}

# Display names for AVCTP profile identifiers (PIDs).
AVCTP_PID_NAMES = {avrcp.AVRCP_PID: 'AVRCP'}


# -----------------------------------------------------------------------------
class PacketTracer:
    """Decode HCI packets exchanged between a host and a controller and emit a
    human-readable trace of them.

    One `Analyzer` is kept per direction (host->controller and
    controller->host). The two analyzers are linked as peers so that L2CAP
    channel state learned from traffic in one direction can be used to decode
    traffic in the other.
    """

    class AclStream:
        """Reassemble and decode the ACL data of one connection handle, in one
        direction.

        `peer` links to the stream for the same connection handle in the
        opposite direction, once both exist (see `Analyzer.start_acl_stream`).
        """

        # PSM of each dynamic L2CAP channel seen on this stream, keyed by the
        # connection request's source_cid or the connection response's
        # destination_cid.
        psms: MutableMapping[int, int]
        # Stream for the same connection handle in the other direction.
        peer: Optional[PacketTracer.AclStream]
        # Message reassemblers, looked up by the CID of L2CAP PDUs carried on
        # this stream.
        avdtp_assemblers: MutableMapping[int, avdtp.MessageAssembler]
        avctp_assemblers: MutableMapping[int, avctp.MessageAssembler]

        def __init__(self, analyzer: PacketTracer.Analyzer) -> None:
            """Create an empty stream that reports through *analyzer*."""
            self.analyzer = analyzer
            # ACL fragments are reassembled and handed to on_acl_pdu.
            self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
            self.avdtp_assemblers = {}  # AVDTP assemblers, by source_cid
            self.avctp_assemblers = {}  # AVCTP assemblers, by source_cid
            self.psms = {}  # PSM, by source_cid
            self.peer = None

        def on_acl_pdu(self, pdu: bytes) -> None:
            """Decode one L2CAP PDU and emit it plus any upper-layer decoding.

            The raw L2CAP PDU is always emitted first. ATT and SMP payloads are
            then decoded on their fixed CIDs. Signaling frames are decoded and
            also used to learn channel PSMs. Any other CID is decoded according
            to the PSM learned for it, if any.
            """
            l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
            self.analyzer.emit(l2cap_pdu)

            if l2cap_pdu.cid == ATT_CID:
                self.analyzer.emit(ATT_PDU.from_bytes(l2cap_pdu.payload))
            elif l2cap_pdu.cid == SMP_CID:
                self.analyzer.emit(SMP_Command.from_bytes(l2cap_pdu.payload))
            elif l2cap_pdu.cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
                control_frame = L2CAP_Control_Frame.from_bytes(l2cap_pdu.payload)
                self.analyzer.emit(control_frame)
                self._on_control_frame(control_frame)
            else:
                self._on_channel_pdu(l2cap_pdu)

        def _on_control_frame(self, control_frame: L2CAP_Control_Frame) -> None:
            """Track L2CAP connection requests/responses to learn channel PSMs."""
            if control_frame.code == L2CAP_CONNECTION_REQUEST:
                connection_request = cast(L2CAP_Connection_Request, control_frame)
                # Remember the requested PSM until the peer's response arrives.
                self.psms[connection_request.source_cid] = connection_request.psm
            elif control_frame.code == L2CAP_CONNECTION_RESPONSE:
                self._on_connection_response(
                    cast(L2CAP_Connection_Response, control_frame)
                )

        def _on_connection_response(
            self, connection_response: L2CAP_Connection_Response
        ) -> None:
            """Record the PSM of a newly accepted channel and, for AVDTP and
            AVCTP, create message reassemblers for both directions."""
            if (
                connection_response.result
                != L2CAP_Connection_Response.CONNECTION_SUCCESSFUL
            ):
                return
            if self.peer is None:
                return

            # The peer stream recorded the PSM from the matching connection
            # request, keyed by what this response carries as source_cid.
            psm = self.peer.psms.get(connection_response.source_cid)
            if not psm:
                return

            # Found a pending connection
            self.psms[connection_response.destination_cid] = psm

            # For AVDTP/AVCTP connections, create a packet assembler for each
            # direction. NOTE(review): each key presumably matches the CID that
            # side puts in the PDUs it sends (L2CAP CID semantics) — that is
            # what _on_channel_pdu looks up.
            if psm == avdtp.AVDTP_PSM:
                self.avdtp_assemblers[
                    connection_response.source_cid
                ] = avdtp.MessageAssembler(self.on_avdtp_message)
                self.peer.avdtp_assemblers[
                    connection_response.destination_cid
                ] = avdtp.MessageAssembler(self.peer.on_avdtp_message)
            elif psm == avctp.AVCTP_PSM:
                self.avctp_assemblers[
                    connection_response.source_cid
                ] = avctp.MessageAssembler(self.on_avctp_message)
                self.peer.avctp_assemblers[
                    connection_response.destination_cid
                ] = avctp.MessageAssembler(self.peer.on_avctp_message)

        def _on_channel_pdu(self, l2cap_pdu: L2CAP_PDU) -> None:
            """Decode a PDU on a dynamic channel according to its learned PSM.

            PDUs whose PSM is unknown are not emitted again: `on_acl_pdu` has
            already emitted them as raw L2CAP PDUs.
            """
            if self.peer is None:
                return

            # Try to find the PSM associated with this PDU
            psm = self.peer.psms.get(l2cap_pdu.cid)
            if not psm:
                return

            if psm == sdp.SDP_PSM:
                self.analyzer.emit(sdp.SDP_PDU.from_bytes(l2cap_pdu.payload))
            elif psm == rfcomm.RFCOMM_PSM:
                self.analyzer.emit(
                    rfcomm.RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
                )
            elif psm == avdtp.AVDTP_PSM:
                self.analyzer.emit(
                    f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
                    f'PSM=AVDTP]: {l2cap_pdu.payload.hex()}'
                )
                if avdtp_assembler := self.avdtp_assemblers.get(l2cap_pdu.cid):
                    avdtp_assembler.on_pdu(l2cap_pdu.payload)
            elif psm == avctp.AVCTP_PSM:
                self.analyzer.emit(
                    f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
                    f'PSM=AVCTP]: {l2cap_pdu.payload.hex()}'
                )
                if avctp_assembler := self.avctp_assemblers.get(l2cap_pdu.cid):
                    avctp_assembler.on_pdu(l2cap_pdu.payload)
            else:
                psm_string = name_or_number(PSM_NAMES, psm)
                self.analyzer.emit(
                    f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
                    f'PSM={psm_string}]: {l2cap_pdu.payload.hex()}'
                )

        def on_avdtp_message(
            self, transaction_label: int, message: avdtp.Message
        ) -> None:
            """Emit a reassembled AVDTP message with its transaction label."""
            self.analyzer.emit(
                f'{color("AVDTP", "green")} [{transaction_label}] {message}'
            )

        def on_avctp_message(
            self,
            transaction_label: int,
            is_command: bool,
            ipid: bool,
            pid: int,
            payload: bytes,
        ) -> None:
            """Emit a reassembled AVCTP message.

            AVRCP payloads are decoded as AV/C frames; payloads for other PIDs
            are shown as hex. A leading '#' marks messages with *ipid* set.
            """
            if pid == avrcp.AVRCP_PID:
                details = str(avc.Frame.from_bytes(payload))
            else:
                details = payload.hex()

            c_r = 'Command' if is_command else 'Response'
            self.analyzer.emit(
                f'{color("AVCTP", "green")} '
                f'{c_r}[{transaction_label}][{name_or_number(AVCTP_PID_NAMES, pid)}] '
                f'{"#" if ipid else ""}'
                f'{details}'
            )

        def feed_packet(self, packet: HCI_AclDataPacket) -> None:
            """Feed one HCI ACL data packet into the reassembler, which calls
            `on_acl_pdu` with each PDU it produces."""
            self.packet_assembler.feed_packet(packet)

    class Analyzer:
        """Trace the HCI packets flowing in one direction.

        Each emitted line is prefixed with the packet timestamp (when known)
        and this analyzer's label. `peer` is the analyzer for the opposite
        direction; `PacketTracer.__init__` sets it.
        """

        # ACL streams, by connection handle.
        acl_streams: MutableMapping[int, PacketTracer.AclStream]
        # Analyzer for the opposite direction.
        peer: PacketTracer.Analyzer

        def __init__(self, label: str, emit_message: Callable[..., None]) -> None:
            """Create an analyzer that sends formatted lines to *emit_message*."""
            self.label = label
            self.emit_message = emit_message
            self.acl_streams = {}  # ACL streams, by connection handle
            self.packet_timestamp: Optional[datetime.datetime] = None

        def start_acl_stream(self, connection_handle: int) -> PacketTracer.AclStream:
            """Create and register a stream for *connection_handle*, linking it
            with the peer analyzer's stream for the same handle if one exists."""
            logger.info(
                f'[{self.label}] +++ Creating ACL stream for connection '
                f'0x{connection_handle:04X}'
            )
            stream = PacketTracer.AclStream(self)
            self.acl_streams[connection_handle] = stream

            # Associate with a peer stream if we can
            if peer_stream := self.peer.acl_streams.get(connection_handle):
                stream.peer = peer_stream
                peer_stream.peer = stream

            return stream

        def end_acl_stream(self, connection_handle: int) -> None:
            """Drop the stream for *connection_handle* in both directions.

            The entry is deleted before notifying the peer, so the peer's
            call back into this method finds nothing and the mutual recursion
            stops.
            """
            if connection_handle in self.acl_streams:
                logger.info(
                    f'[{self.label}] --- Removing ACL stream for connection '
                    f'0x{connection_handle:04X}'
                )
                del self.acl_streams[connection_handle]

                # Let the other forwarder know so it can cleanup its stream as well
                self.peer.end_acl_stream(connection_handle)

        def on_packet(
            self, timestamp: Optional[datetime.datetime], packet: HCI_Packet
        ) -> None:
            """Emit *packet* and route it.

            ACL data goes to the stream for its connection handle, which is
            created on first use. A Disconnection Complete event ends the
            stream for that handle.
            """
            self.packet_timestamp = timestamp
            self.emit(packet)

            if packet.hci_packet_type == HCI_ACL_DATA_PACKET:
                acl_packet = cast(HCI_AclDataPacket, packet)
                # Look for an existing stream for this handle, create one if it
                # is the first ACL packet for that connection handle
                if (
                    stream := self.acl_streams.get(acl_packet.connection_handle)
                ) is None:
                    stream = self.start_acl_stream(acl_packet.connection_handle)
                stream.feed_packet(acl_packet)
            elif packet.hci_packet_type == HCI_EVENT_PACKET:
                event_packet = cast(HCI_Event, packet)
                if event_packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT:
                    self.end_acl_stream(
                        cast(HCI_Disconnection_Complete_Event, packet).connection_handle
                    )

        def emit(self, message: Any) -> None:
            """Send *message* to `emit_message`, prefixed with the current
            packet timestamp (if any) and this analyzer's label."""
            if self.packet_timestamp is not None:
                prefix = f"[{self.packet_timestamp.strftime('%Y-%m-%d %H:%M:%S.%f')}]"
            else:
                prefix = ""
            self.emit_message(f'{prefix}[{self.label}] {message}')

    def trace(
        self,
        packet: HCI_Packet,
        direction: int = 0,
        timestamp: Optional[datetime.datetime] = None,
    ) -> None:
        """Trace one HCI packet.

        Args:
            packet: the packet to decode and emit.
            direction: 0 for host->controller; any other value for
                controller->host.
            timestamp: optional capture time, shown as a prefix on every line
                emitted for this packet.
        """
        if direction == 0:
            self.host_to_controller_analyzer.on_packet(timestamp, packet)
        else:
            self.controller_to_host_analyzer.on_packet(timestamp, packet)

    def __init__(
        self,
        host_to_controller_label: str = color('HOST->CONTROLLER', 'blue'),
        controller_to_host_label: str = color('CONTROLLER->HOST', 'cyan'),
        emit_message: Callable[..., None] = logger.info,
    ) -> None:
        """Create one analyzer per direction and link them as peers.

        *emit_message* receives every formatted trace line; it defaults to
        this module's logger at INFO level.
        """
        self.host_to_controller_analyzer = PacketTracer.Analyzer(
            host_to_controller_label, emit_message
        )
        self.controller_to_host_analyzer = PacketTracer.Analyzer(
            controller_to_host_label, emit_message
        )
        self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer
        self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer


def generate_irk() -> bytes:
    """Generate a new Identity Resolving Key using `crypto.r()`.

    NOTE(review): presumably a 16-byte random value — confirm against
    `bumble.crypto.r`.
    """
    return crypto.r()


def verify_rpa_with_irk(rpa: Address, irk: bytes) -> bool:
    """Check whether the resolvable private address *rpa* resolves with *irk*.

    The first three bytes of ``bytes(rpa)`` are taken as the hash and the
    bytes from index 3 on as prand. The hash is recomputed with
    ``crypto.ah(irk, prand)``, and its first three bytes are compared with
    the given hash.
    """
    # NOTE(review): assumes bytes(rpa) is little-endian, so the hash is the
    # least-significant 24 bits of the address — confirm.
    rpa_bytes = bytes(rpa)
    prand_given = rpa_bytes[3:]
    hash_given = rpa_bytes[:3]
    hash_local = crypto.ah(irk, prand_given)
    return hash_local[:3] == hash_given