1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19
20from collections.abc import Callable, MutableMapping
21import datetime
22from typing import cast, Any, Optional
23import logging
24
25from bumble import avc
26from bumble import avctp
27from bumble import avdtp
28from bumble import avrcp
29from bumble import crypto
30from bumble import rfcomm
31from bumble import sdp
32from bumble.colors import color
33from bumble.att import ATT_CID, ATT_PDU
34from bumble.smp import SMP_CID, SMP_Command
35from bumble.core import name_or_number
36from bumble.l2cap import (
37    L2CAP_PDU,
38    L2CAP_CONNECTION_REQUEST,
39    L2CAP_CONNECTION_RESPONSE,
40    L2CAP_SIGNALING_CID,
41    L2CAP_LE_SIGNALING_CID,
42    L2CAP_Control_Frame,
43    L2CAP_Connection_Request,
44    L2CAP_Connection_Response,
45)
46from bumble.hci import (
47    Address,
48    HCI_EVENT_PACKET,
49    HCI_ACL_DATA_PACKET,
50    HCI_DISCONNECTION_COMPLETE_EVENT,
51    HCI_AclDataPacketAssembler,
52    HCI_Packet,
53    HCI_Event,
54    HCI_AclDataPacket,
55    HCI_Disconnection_Complete_Event,
56)
57
58
59# -----------------------------------------------------------------------------
60# Logging
61# -----------------------------------------------------------------------------
# Module-level logger; also the default `emit_message` sink of `PacketTracer`.
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Display names for the L2CAP PSM values that the tracer knows about. Looked up
# through `name_or_number` when labelling the payload of a traced L2CAP channel.
PSM_NAMES = {
    rfcomm.RFCOMM_PSM: 'RFCOMM',
    sdp.SDP_PSM: 'SDP',
    avdtp.AVDTP_PSM: 'AVDTP',
    avctp.AVCTP_PSM: 'AVCTP',
    # TODO: add more PSM values
}

# Display names for AVCTP profile identifiers (PID). Looked up through
# `name_or_number` when a reassembled AVCTP message is traced.
AVCTP_PID_NAMES = {avrcp.AVRCP_PID: 'AVRCP'}
75
76
77# -----------------------------------------------------------------------------
class PacketTracer:
    """
    Trace HCI packets in both directions and emit a readable line per layer.

    Two `Analyzer` instances are created, one per direction, and linked to
    each other as peers. Packets are handed to `trace`, which routes them to
    the analyzer for the given direction.
    """

    class AclStream:
        """
        Per-connection, per-direction decoder for ACL data.

        ACL packets are fed to an `HCI_AclDataPacketAssembler`, which calls
        `on_acl_pdu` for each PDU it delivers. Channel bookkeeping (PSM by
        CID, AVDTP/AVCTP assemblers by CID) is shared with the `peer` stream,
        i.e. the stream for the same connection handle in the other direction.
        """

        psms: MutableMapping[int, int]
        peer: Optional[PacketTracer.AclStream]
        avdtp_assemblers: MutableMapping[int, avdtp.MessageAssembler]
        avctp_assemblers: MutableMapping[int, avctp.MessageAssembler]

        def __init__(self, analyzer: PacketTracer.Analyzer) -> None:
            """Create a stream that reports through `analyzer`."""
            self.analyzer = analyzer
            self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
            self.avdtp_assemblers = {}  # AVDTP assemblers, by source_cid
            self.avctp_assemblers = {}  # AVCTP assemblers, by source_cid
            self.psms = {}  # PSM, by source_cid
            self.peer = None

        def on_acl_pdu(self, pdu: bytes) -> None:
            """
            Decode one L2CAP PDU and emit it, plus any upper-layer decoding.

            The L2CAP PDU itself is always emitted exactly once. Depending on
            the channel ID, the payload is then decoded as ATT, SMP, L2CAP
            signaling, or dispatched according to the PSM of the channel.
            """
            l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
            self.analyzer.emit(l2cap_pdu)

            cid = l2cap_pdu.cid
            if cid == ATT_CID:
                self.analyzer.emit(ATT_PDU.from_bytes(l2cap_pdu.payload))
            elif cid == SMP_CID:
                self.analyzer.emit(SMP_Command.from_bytes(l2cap_pdu.payload))
            elif cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
                control_frame = L2CAP_Control_Frame.from_bytes(l2cap_pdu.payload)
                self._on_control_frame(control_frame)
            else:
                self._on_channel_pdu(l2cap_pdu)

        def _on_control_frame(self, control_frame: L2CAP_Control_Frame) -> None:
            """Emit a signaling frame and track channels it opens."""
            self.analyzer.emit(control_frame)

            if control_frame.code == L2CAP_CONNECTION_REQUEST:
                # Remember the requested PSM until the peer direction answers
                connection_request = cast(L2CAP_Connection_Request, control_frame)
                self.psms[connection_request.source_cid] = connection_request.psm
                return

            if control_frame.code != L2CAP_CONNECTION_RESPONSE:
                return

            connection_response = cast(L2CAP_Connection_Response, control_frame)
            if (
                connection_response.result
                != L2CAP_Connection_Response.CONNECTION_SUCCESSFUL
            ):
                return

            # The matching request, if any, was recorded by the peer stream
            peer = self.peer
            if peer is None:
                return
            psm = peer.psms.get(connection_response.source_cid)
            if not psm:
                return

            # Found a pending connection
            self.psms[connection_response.destination_cid] = psm

            # For AVDTP and AVCTP connections, create a message assembler for
            # each direction
            if psm == avdtp.AVDTP_PSM:
                self.avdtp_assemblers[connection_response.source_cid] = (
                    avdtp.MessageAssembler(self.on_avdtp_message)
                )
                peer.avdtp_assemblers[connection_response.destination_cid] = (
                    avdtp.MessageAssembler(peer.on_avdtp_message)
                )
            elif psm == avctp.AVCTP_PSM:
                self.avctp_assemblers[connection_response.source_cid] = (
                    avctp.MessageAssembler(self.on_avctp_message)
                )
                peer.avctp_assemblers[connection_response.destination_cid] = (
                    avctp.MessageAssembler(peer.on_avctp_message)
                )

        def _on_channel_pdu(self, l2cap_pdu: L2CAP_PDU) -> None:
            """Decode a PDU received on a dynamically allocated channel."""
            # Try to find the PSM associated with this PDU
            psm = self.peer.psms.get(l2cap_pdu.cid) if self.peer else None
            if not psm:
                # Unknown channel: the raw L2CAP PDU was already emitted by
                # on_acl_pdu, so emitting it again here would duplicate it.
                return

            if psm == sdp.SDP_PSM:
                self.analyzer.emit(sdp.SDP_PDU.from_bytes(l2cap_pdu.payload))
            elif psm == rfcomm.RFCOMM_PSM:
                self.analyzer.emit(rfcomm.RFCOMM_Frame.from_bytes(l2cap_pdu.payload))
            elif psm == avdtp.AVDTP_PSM:
                self._emit_channel_payload(l2cap_pdu, 'AVDTP')
                if avdtp_assembler := self.avdtp_assemblers.get(l2cap_pdu.cid):
                    avdtp_assembler.on_pdu(l2cap_pdu.payload)
            elif psm == avctp.AVCTP_PSM:
                self._emit_channel_payload(l2cap_pdu, 'AVCTP')
                if avctp_assembler := self.avctp_assemblers.get(l2cap_pdu.cid):
                    avctp_assembler.on_pdu(l2cap_pdu.payload)
            else:
                self._emit_channel_payload(
                    l2cap_pdu, name_or_number(PSM_NAMES, psm)
                )

        def _emit_channel_payload(self, l2cap_pdu: L2CAP_PDU, psm_name: str) -> None:
            """Emit the hex payload of a channel PDU, labelled with its PSM."""
            self.analyzer.emit(
                f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
                f'PSM={psm_name}]: {l2cap_pdu.payload.hex()}'
            )

        def on_avdtp_message(
            self, transaction_label: int, message: avdtp.Message
        ) -> None:
            """Emit an AVDTP message delivered by an AVDTP assembler."""
            self.analyzer.emit(
                f'{color("AVDTP", "green")} [{transaction_label}] {message}'
            )

        def on_avctp_message(
            self,
            transaction_label: int,
            is_command: bool,
            ipid: bool,
            pid: int,
            payload: bytes,
        ) -> None:
            """
            Emit an AVCTP message delivered by an AVCTP assembler.

            When `pid` is the AVRCP PID the payload is parsed as an AV/C
            frame, otherwise it is shown as hex. A '#' marker is inserted
            before the details when `ipid` is set.
            """
            if pid == avrcp.AVRCP_PID:
                avc_frame = avc.Frame.from_bytes(payload)
                details = str(avc_frame)
            else:
                details = payload.hex()

            c_r = 'Command' if is_command else 'Response'
            self.analyzer.emit(
                f'{color("AVCTP", "green")} '
                f'{c_r}[{transaction_label}][{name_or_number(AVCTP_PID_NAMES, pid)}] '
                f'{"#" if ipid else ""}'
                f'{details}'
            )

        def feed_packet(self, packet: HCI_AclDataPacket) -> None:
            """Feed one ACL data packet to the packet assembler."""
            self.packet_assembler.feed_packet(packet)

    class Analyzer:
        """
        Analyzer for the packets of one direction.

        Keeps one `AclStream` per connection handle and formats every emitted
        message with the current packet timestamp (if any) and this analyzer's
        label. The `peer` attribute is not set by `__init__`; it must be
        assigned before ACL packets or disconnection events are processed
        (`PacketTracer.__init__` does this).
        """

        acl_streams: MutableMapping[int, PacketTracer.AclStream]
        peer: PacketTracer.Analyzer

        def __init__(self, label: str, emit_message: Callable[..., None]) -> None:
            """
            Args:
                label: text inserted in every emitted line.
                emit_message: callable invoked with each formatted line.
            """
            self.label = label
            self.emit_message = emit_message
            self.acl_streams = {}  # ACL streams, by connection handle
            self.packet_timestamp: Optional[datetime.datetime] = None

        def start_acl_stream(self, connection_handle: int) -> PacketTracer.AclStream:
            """Create, register and return the stream for `connection_handle`."""
            logger.info(
                f'[{self.label}] +++ Creating ACL stream for connection '
                f'0x{connection_handle:04X}'
            )
            stream = PacketTracer.AclStream(self)
            self.acl_streams[connection_handle] = stream

            # Associate with a peer stream if we can
            if peer_stream := self.peer.acl_streams.get(connection_handle):
                stream.peer = peer_stream
                peer_stream.peer = stream

            return stream

        def end_acl_stream(self, connection_handle: int) -> None:
            """Remove the stream for `connection_handle`, here and in the peer."""
            if connection_handle not in self.acl_streams:
                # Nothing to do; this is also what ends the mutual recursion
                # with the peer analyzer.
                return

            logger.info(
                f'[{self.label}] --- Removing ACL stream for connection '
                f'0x{connection_handle:04X}'
            )
            del self.acl_streams[connection_handle]

            # Let the other forwarder know so it can cleanup its stream as well
            self.peer.end_acl_stream(connection_handle)

        def on_packet(
            self, timestamp: Optional[datetime.datetime], packet: HCI_Packet
        ) -> None:
            """
            Emit `packet` and update ACL stream state.

            ACL data packets are fed to the stream for their connection
            handle (created on first use). A Disconnection Complete event
            removes the stream for its connection handle.
            """
            self.packet_timestamp = timestamp
            self.emit(packet)

            if packet.hci_packet_type == HCI_ACL_DATA_PACKET:
                acl_packet = cast(HCI_AclDataPacket, packet)
                # Look for an existing stream for this handle, create one if it is the
                # first ACL packet for that connection handle
                stream = self.acl_streams.get(acl_packet.connection_handle)
                if stream is None:
                    stream = self.start_acl_stream(acl_packet.connection_handle)
                stream.feed_packet(acl_packet)
            elif packet.hci_packet_type == HCI_EVENT_PACKET:
                event_packet = cast(HCI_Event, packet)
                if event_packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT:
                    self.end_acl_stream(
                        cast(HCI_Disconnection_Complete_Event, packet).connection_handle
                    )

        def emit(self, message: Any) -> None:
            """Pass `message` to `emit_message`, prefixed by timestamp and label."""
            if self.packet_timestamp:
                prefix = f"[{self.packet_timestamp.strftime('%Y-%m-%d %H:%M:%S.%f')}]"
            else:
                prefix = ""
            self.emit_message(f'{prefix}[{self.label}] {message}')

    def __init__(
        self,
        host_to_controller_label: str = color('HOST->CONTROLLER', 'blue'),
        controller_to_host_label: str = color('CONTROLLER->HOST', 'cyan'),
        emit_message: Callable[..., None] = logger.info,
    ) -> None:
        """
        Args:
            host_to_controller_label: label for packets traced with direction 0.
            controller_to_host_label: label for packets traced with any other
                direction.
            emit_message: callable invoked with each formatted trace line.
        """
        self.host_to_controller_analyzer = PacketTracer.Analyzer(
            host_to_controller_label, emit_message
        )
        self.controller_to_host_analyzer = PacketTracer.Analyzer(
            controller_to_host_label, emit_message
        )

        # Link the two directions so that they can share channel state
        self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer
        self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer

    def trace(
        self,
        packet: HCI_Packet,
        direction: int = 0,
        timestamp: Optional[datetime.datetime] = None,
    ) -> None:
        """
        Trace one HCI packet.

        Args:
            packet: the packet to trace.
            direction: 0 for host-to-controller, any other value for
                controller-to-host.
            timestamp: optional timestamp shown as a prefix of emitted lines.
        """
        if direction == 0:
            self.host_to_controller_analyzer.on_packet(timestamp, packet)
        else:
            self.controller_to_host_analyzer.on_packet(timestamp, packet)
294
295
def generate_irk() -> bytes:
    """
    Return a new key value obtained from `crypto.r()`.

    NOTE(review): by its name this is meant to be used as an Identity
    Resolving Key; the size and randomness of the value are whatever
    `crypto.r()` provides.
    """
    irk = crypto.r()
    return irk
298
299
def verify_rpa_with_irk(rpa: Address, irk: bytes) -> bool:
    """
    Check whether the address `rpa` matches the key `irk`.

    The byte form of the address is split in two: the first three bytes are
    taken as the hash and the remaining bytes as the prand. The hash is
    recomputed with `crypto.ah(irk, prand)` and compared with the one carried
    by the address.

    Returns:
        True if the first three bytes of the recomputed hash equal the hash
        found in the address, False otherwise.
    """
    raw_address = bytes(rpa)
    address_hash, address_prand = raw_address[:3], raw_address[3:]
    expected_hash = crypto.ah(irk, address_prand)
    return expected_hash[:3] == address_hash
306