1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import datetime
19import logging
20import os
21import struct
22
23import click
24
25from bumble.colors import color
26from bumble import hci
27from bumble.transport.common import PacketReader
28from bumble.helpers import PacketTracer
29
30
31# -----------------------------------------------------------------------------
32# Logging
33# -----------------------------------------------------------------------------
# Module-level logger, named after this module
logger = logging.getLogger(__name__)
35
36
37# -----------------------------------------------------------------------------
38# Classes
39# -----------------------------------------------------------------------------
class SnoopPacketReader:
    '''
    Reader that reads HCI packets from a "snoop" file (based on RFC 1761, but not
    exactly the same...)

    The file starts with a 16-byte header: the 8-byte identification pattern,
    followed by a big-endian 32-bit version number and a big-endian 32-bit
    datalink type. Only the H1 and H4 datalink types are accepted.

    The rest of the file is a sequence of records. Each record starts with a
    24-byte big-endian header (original length, included length, flags,
    cumulative drops, 64-bit timestamp) followed by `included length` bytes of
    packet data.

    Attributes:
        source: binary file-like object the file is read from.
        at_end: True once a record header could not be read in full.
        version_number: version number found in the file header.
        data_link_type: datalink type found in the file header.
    '''

    DATALINK_H1 = 1001
    DATALINK_H4 = 1002
    DATALINK_BSCP = 1003
    DATALINK_H5 = 1004

    IDENTIFICATION_PATTERN = b'btsnoop\0'
    # A record timestamp equal to TIMESTAMP_DELTA (in microseconds) corresponds
    # to TIMESTAMP_ANCHOR.
    TIMESTAMP_ANCHOR = datetime.datetime(2000, 1, 1)
    TIMESTAMP_DELTA = 0x00E03AB44A676000
    ONE_MICROSECOND = datetime.timedelta(microseconds=1)

    def __init__(self, source):
        '''
        Read and validate the file header.

        Args:
            source: binary file-like object, positioned at the start of the
                snoop data.

        Raises:
            ValueError: if the identification pattern does not match, or if the
                datalink type is neither H1 nor H4.
        '''
        self.source = source
        self.at_end = False

        # Read the header
        identification_pattern = source.read(8)
        if identification_pattern != self.IDENTIFICATION_PATTERN:
            raise ValueError(
                'not a valid snoop file, unexpected identification pattern'
            )
        (self.version_number, self.data_link_type) = struct.unpack(
            '>II', source.read(8)
        )
        if self.data_link_type not in (self.DATALINK_H4, self.DATALINK_H1):
            raise ValueError(f'datalink type {self.data_link_type} not supported')

    def next_packet(self):
        '''
        Read the next record from the source.

        Returns:
            A `(timestamp, direction, packet)` tuple, where `timestamp` is a
            `datetime.datetime`, `direction` is bit 0 of the record flags and
            `packet` is the packet bytes. For the H1 datalink type, the packet
            bytes are prefixed with a packet type byte derived from the flags.

            `(None, 0, None)` is returned when no packet is available: at the
            end of the source (in which case `at_end` is also set), when the
            record was truncated at capture time, or when the source ends in
            the middle of a record's data.
        '''
        # Read the record header
        header = self.source.read(24)
        if len(header) < 24:
            self.at_end = True
            return (None, 0, None)

        # Parse the header
        (
            original_length,
            included_length,
            packet_flags,
            _cumulative_drops,
            timestamp,
        ) = struct.unpack('>IIIIQ', header)

        # Skip truncated packets
        if original_length != included_length:
            print(
                color(
                    f"!!! truncated packet ({included_length}/{original_length})", "red"
                )
            )
            self.source.read(included_length)
            return (None, 0, None)

        # Consume the record data before anything else that may raise, so that
        # the source is always left positioned at the start of the next record.
        payload = self.source.read(included_length)
        if len(payload) < included_length:
            # The source ended in the middle of the record: there is no complete
            # packet to return. The next call will find no header and set
            # `at_end`.
            return (None, 0, None)

        # Convert the timestamp to a datetime object.
        ts_dt = self.TIMESTAMP_ANCHOR + datetime.timedelta(
            microseconds=timestamp - self.TIMESTAMP_DELTA
        )

        direction = packet_flags & 1

        if self.data_link_type == self.DATALINK_H1:
            # The packet is un-encapsulated, look at the flags to figure out its type
            if direction:
                # Controller -> Host
                if packet_flags & 2:
                    packet_type = hci.HCI_EVENT_PACKET
                else:
                    packet_type = hci.HCI_ACL_DATA_PACKET
            else:
                # Host -> Controller
                if packet_flags & 2:
                    packet_type = hci.HCI_COMMAND_PACKET
                else:
                    packet_type = hci.HCI_ACL_DATA_PACKET

            payload = bytes([packet_type]) + payload

        # Same tuple shape for all datalink types: callers unpack three values.
        return (ts_dt, direction, payload)
124
125
126# -----------------------------------------------------------------------------
class Printer:
    '''
    Prints messages to the standard output, each one prefixed with a running
    message number.
    '''

    def __init__(self):
        # Number of messages printed so far
        self.index = 0

    def print(self, message: str) -> None:
        '''
        Print `message`, preceded by its 1-based sequence number in brackets
        (right-aligned in a field of 8 characters).
        '''
        self.index = self.index + 1
        print(f"[{self.index:8}]{message}")
134
135
136# -----------------------------------------------------------------------------
137# Main
138# -----------------------------------------------------------------------------
@click.command()
@click.option(
    '--format',
    type=click.Choice(['h4', 'snoop']),
    default='h4',
    help='Format of the input file',
)
@click.option(
    '--vendors',
    type=click.Choice(['android', 'zephyr']),
    multiple=True,
    help='Support vendor-specific commands (list one or more)',
)
@click.argument('filename')
# pylint: disable=redefined-builtin
def main(format, vendors, filename):
    '''Print a trace of the HCI packets found in FILENAME.'''
    # Import the requested vendor modules (imported for their side effects only)
    for vendor in vendors:
        if vendor == 'android':
            import bumble.vendor.android.hci
        elif vendor == 'zephyr':
            import bumble.vendor.zephyr.hci

    # The file is only needed while packets are being read: close it on the way
    # out, including when an error propagates.
    with open(filename, 'rb') as input_file:
        if format == 'h4':
            packet_reader = PacketReader(input_file)

            # The H4 reader only yields packets: supply a neutral timestamp and
            # direction so that both formats produce the same tuple shape.
            def read_next_packet():
                return (None, 0, packet_reader.next_packet())

        else:
            packet_reader = SnoopPacketReader(input_file)
            read_next_packet = packet_reader.next_packet

        printer = Printer()
        tracer = PacketTracer(emit_message=printer.print)

        while not packet_reader.at_end:
            try:
                (timestamp, direction, packet) = read_next_packet()
                if packet:
                    tracer.trace(
                        hci.HCI_Packet.from_bytes(packet), direction, timestamp
                    )
                elif not packet_reader.at_end:
                    # No packet, but not the end of the input either: the
                    # record was skipped by the reader.
                    printer.print(color("[TRUNCATED]", "red"))
            except Exception as error:
                # Best effort: report the problem and move on to the next packet.
                # `Logger.exception` requires a message argument.
                logger.exception('error while processing packet')
                print(color(f'!!! {error}', 'red'))
185
186
187# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # The log level comes from the BUMBLE_LOGLEVEL environment variable
    # (case-insensitive), and defaults to WARNING.
    log_level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()
    logging.basicConfig(level=log_level)
    main()  # pylint: disable=no-value-for-parameter
191