1# Copyright 2023 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19from dataclasses import dataclass
20
21from bumble import core
22
23
24# -----------------------------------------------------------------------------
class BitReader:
    """Simple but not optimized bit stream reader.

    Bits are consumed most-significant first. ``bit_position`` counts the bits
    consumed so far. Up to 4 bytes of ``data`` are held in ``cache``; its low
    ``bits_cached`` bits have not been consumed yet, and ``byte_position`` is the
    index of the next byte to load. The three always satisfy
    ``8 * byte_position - bits_cached == bit_position``.
    """

    data: bytes
    byte_position: int
    bit_position: int
    cache: int
    bits_cached: int

    def __init__(self, data: bytes):
        self.data = data
        self.byte_position = 0
        self.bit_position = 0
        self.cache = 0
        self.bits_cached = 0

    def read(self, bits: int) -> int:
        """Read up to 32 bits and return them as an unsigned integer.

        Raises:
            core.InvalidArgumentError: if ``bits`` is larger than 32, or if
                fewer than ``bits`` bits remain in the data.
        """

        if bits > 32:
            raise core.InvalidArgumentError('maximum read size is 32')

        if self.bits_cached >= bits:
            # We have enough bits.
            self.bits_cached -= bits
            self.bit_position += bits
            return (self.cache >> self.bits_cached) & ((1 << bits) - 1)

        # Read more cache, up to 32 bits
        feed_bytes = self.data[self.byte_position : self.byte_position + 4]
        feed_size = len(feed_bytes)
        feed_int = int.from_bytes(feed_bytes, byteorder='big')
        if 8 * feed_size + self.bits_cached < bits:
            raise core.InvalidArgumentError('trying to read past the data')
        self.byte_position += feed_size

        # Combine the leftover bits of the old cache (high part of the result)
        # with the leading bits of the new cache (low part of the result).
        cache = self.cache & ((1 << self.bits_cached) - 1)
        new_bits = bits - self.bits_cached
        self.bits_cached = 8 * feed_size - new_bits
        result = (feed_int >> self.bits_cached) | (cache << new_bits)
        self.cache = feed_int

        self.bit_position += bits
        return result

    def read_bytes(self, count: int) -> bytes:
        """Read ``count`` bytes starting at the current (possibly unaligned) bit.

        Raises:
            core.InvalidArgumentError: if fewer than ``count`` bytes remain.
        """
        if self.bit_position + 8 * count > 8 * len(self.data):
            raise core.InvalidArgumentError('not enough data')

        if self.bit_position % 8:
            # Not byte aligned: assemble the bytes 8 bits at a time.
            return bytes(self.read(8) for _ in range(count))

        # Byte aligned: slice the data directly, drop the cache, and resume
        # loading right *after* the returned bytes so they are not read twice.
        offset = self.bit_position // 8
        self.byte_position = offset + count
        self.bits_cached = 0
        self.cache = 0
        self.bit_position += 8 * count
        return self.data[offset : offset + count]

    def bits_left(self) -> int:
        """Return the number of bits not consumed yet."""
        return (8 * len(self.data)) - self.bit_position

    def skip(self, bits: int) -> None:
        """Discard the next ``bits`` bits (in chunks of at most 32 bits)."""
        # Slow, but simple...
        while bits:
            if bits > 32:
                self.read(32)
                bits -= 32
            else:
                self.read(bits)
                break
102
103
104# -----------------------------------------------------------------------------
class AacAudioRtpPacket:
    """AAC payload encapsulated in an RTP packet payload.

    The payload is parsed as an LATM ``AudioMuxElement`` with
    ``muxConfigPresent == 1`` (ISO/IEC 14496-3), i.e. the ``StreamMuxConfig``
    is carried in-band. Only a single program with a single layer is supported.
    """

    @staticmethod
    def latm_value(reader: BitReader) -> int:
        """Read a LatmGetValue() field.

        A 2-bit ``bytesForValue`` prefix is followed by ``bytesForValue + 1``
        bytes, most significant first.
        """
        bytes_for_value = reader.read(2)
        value = 0
        for _ in range(bytes_for_value + 1):
            value = value * 256 + reader.read(8)
        return value

    @staticmethod
    def program_config_element(reader: BitReader):
        """Parse a program_config_element().

        Raises:
            core.InvalidPacketError: always; PCEs are not supported.
        """
        raise core.InvalidPacketError('program_config_element not supported')

    @dataclass
    class GASpecificConfig:
        """GASpecificConfig - ISO/IEC 14496-3 Table 4.1.

        Most fields are consumed and discarded. ``core_coder_delay`` and
        ``layer_nr`` are kept as attributes, but only when present in the stream.
        """

        def __init__(
            self, reader: BitReader, channel_configuration: int, audio_object_type: int
        ) -> None:
            reader.read(1)  # frameLengthFlag
            depends_on_core_coder = reader.read(1)
            if depends_on_core_coder:
                self.core_coder_delay = reader.read(14)
            extension_flag = reader.read(1)
            if not channel_configuration:
                AacAudioRtpPacket.program_config_element(reader)
            if audio_object_type in (6, 20):
                self.layer_nr = reader.read(3)
            if extension_flag:
                if audio_object_type == 22:
                    # Both fields exist only for ER BSAC (object type 22).
                    reader.read(5)  # numOfSubFrame
                    reader.read(11)  # layer_length
                if audio_object_type in (17, 19, 20, 23):
                    reader.read(1)  # aacSectionDataResilienceFlag
                    reader.read(1)  # aacScalefactorDataResilienceFlag
                    reader.read(1)  # aacSpectralDataResilienceFlag
                extension_flag_3 = reader.read(1)
                if extension_flag_3 == 1:
                    raise core.InvalidPacketError('extensionFlag3 == 1 not supported')

    @staticmethod
    def audio_object_type(reader: BitReader):
        """Read a GetAudioObjectType() field - ISO/IEC 14496-3 Table 1.16.

        A 5-bit value, where 31 is an escape for ``32 + `` a 6-bit extension.
        """
        audio_object_type = reader.read(5)
        if audio_object_type == 31:
            audio_object_type = 32 + reader.read(6)

        return audio_object_type

    @dataclass
    class AudioSpecificConfig:
        """AudioSpecificConfig - ISO/IEC 14496-3 Table 1.15.

        Fields that are only present for some object types keep their defaults
        when absent: -1 for the SBR/PS flags and the extension sampling frequency
        index, 0 for the other extension fields.
        """

        audio_object_type: int
        sampling_frequency_index: int
        sampling_frequency: int
        channel_configuration: int
        sbr_present_flag: int
        ps_present_flag: int
        extension_audio_object_type: int
        extension_sampling_frequency_index: int
        extension_sampling_frequency: int
        extension_channel_configuration: int

        SAMPLING_FREQUENCIES = [
            96000,
            88200,
            64000,
            48000,
            44100,
            32000,
            24000,
            22050,
            16000,
            12000,
            11025,
            8000,
            7350,
        ]

        @classmethod
        def _read_sampling_frequency(cls, reader: BitReader) -> tuple[int, int]:
            """Read a 4-bit sampling frequency index (0xF = explicit 24-bit value).

            Returns:
                ``(index, frequency_in_hz)``.

            Raises:
                core.InvalidPacketError: for a reserved index.
            """
            index = reader.read(4)
            if index == 0xF:
                return index, reader.read(24)
            if index >= len(cls.SAMPLING_FREQUENCIES):
                raise core.InvalidPacketError(
                    f'reserved samplingFrequencyIndex {index}'
                )
            return index, cls.SAMPLING_FREQUENCIES[index]

        def __init__(self, reader: BitReader) -> None:
            # Defaults, so that every dataclass field always exists (the
            # generated __repr__/__eq__ read all of them).
            self.sbr_present_flag = -1
            self.ps_present_flag = -1
            self.extension_audio_object_type = 0
            self.extension_sampling_frequency_index = -1
            self.extension_sampling_frequency = 0
            self.extension_channel_configuration = 0

            self.audio_object_type = AacAudioRtpPacket.audio_object_type(reader)
            (
                self.sampling_frequency_index,
                self.sampling_frequency,
            ) = self._read_sampling_frequency(reader)
            self.channel_configuration = reader.read(4)
            if self.audio_object_type in (5, 29):
                # Explicit SBR (5) / SBR+PS (29) signaling: the extension
                # sampling frequency and the core object type follow.
                self.extension_audio_object_type = 5
                self.sbr_present_flag = 1
                if self.audio_object_type == 29:
                    self.ps_present_flag = 1
                (
                    self.extension_sampling_frequency_index,
                    self.extension_sampling_frequency,
                ) = self._read_sampling_frequency(reader)
                self.audio_object_type = AacAudioRtpPacket.audio_object_type(reader)
                if self.audio_object_type == 22:
                    self.extension_channel_configuration = reader.read(4)

            if self.audio_object_type not in (
                1, 2, 3, 4, 6, 7, 17, 19, 20, 21, 22, 23
            ):
                raise core.InvalidPacketError(
                    f'audioObjectType {self.audio_object_type} not supported'
                )
            self.ga_specific_config = AacAudioRtpPacket.GASpecificConfig(
                reader, self.channel_configuration, self.audio_object_type
            )

            if self.audio_object_type in (17, 19, 20, 21, 22, 23):
                # Error-resilient object types carry a 2-bit epConfig; values 2
                # and 3 would be followed by an ErrorProtectionSpecificConfig.
                ep_config = reader.read(2)
                if ep_config in (2, 3):
                    raise core.InvalidPacketError(f'epConfig {ep_config} not supported')

            # TODO: backward-compatible SBR/PS signaling is not parsed. When
            # extensionAudioObjectType != 5 and at least 16 bits of the ASC
            # remain, a syncExtensionType of 0x2B7 introduces an extension object
            # type: 5 (SBR: sbrPresentFlag, an optional extension sampling
            # frequency, then possibly syncExtensionType 0x548 + psPresentFlag)
            # or 22 (sbrPresentFlag, optional extension sampling frequency,
            # extensionChannelConfiguration). Parsing it requires knowing how
            # many ASC bits remain.

    @dataclass
    class StreamMuxConfig:
        """StreamMuxConfig - ISO/IEC 14496-3 Table 1.42.

        Only audioMuxVersionA == 0, allStreamsSameTimeFraming == 1, a single
        program and a single layer are supported.
        """

        other_data_present: int
        other_data_len_bits: int
        audio_specific_config: AacAudioRtpPacket.AudioSpecificConfig

        def __init__(self, reader: BitReader) -> None:
            audio_mux_version = reader.read(1)
            audio_mux_version_a = reader.read(1) if audio_mux_version == 1 else 0
            if audio_mux_version_a != 0:
                raise core.InvalidPacketError('audioMuxVersionA != 0 not supported')
            if audio_mux_version == 1:
                AacAudioRtpPacket.latm_value(reader)  # taraBufferFullness
            all_streams_same_time_framing = reader.read(1)
            if not all_streams_same_time_framing:
                # PayloadLengthInfo uses a different (chunked) layout then.
                raise core.InvalidPacketError(
                    'allStreamsSameTimeFraming == 0 not supported'
                )
            reader.read(6)  # numSubFrames
            num_program = reader.read(4)
            if num_program != 0:
                raise core.InvalidPacketError('num_program != 0 not supported')
            num_layer = reader.read(3)
            if num_layer != 0:
                raise core.InvalidPacketError('num_layer != 0 not supported')

            # With a single program and layer, useSameConfig is implicitly 0.
            if audio_mux_version == 0:
                self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
                    reader
                )
            else:
                asc_len = AacAudioRtpPacket.latm_value(reader)
                marker = reader.bit_position
                self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
                    reader
                )
                audio_specific_config_len = reader.bit_position - marker
                if asc_len < audio_specific_config_len:
                    raise core.InvalidPacketError('audio_specific_config_len > asc_len')
                # fillBits: skip whatever part of the ASC was not parsed.
                reader.skip(asc_len - audio_specific_config_len)

            frame_length_type = reader.read(3)
            if frame_length_type == 0:
                reader.read(8)  # latmBufferFullness
            elif frame_length_type == 1:
                # NOTE(review): with frameLengthType == 1 the payload length is
                # implied by frameLength, but AudioMuxElement always reads
                # MuxSlotLengthBytes - such streams are likely mis-parsed.
                reader.read(9)  # frameLength
            else:
                raise core.InvalidPacketError(
                    f'frame_length_type {frame_length_type} not supported'
                )

            self.other_data_present = reader.read(1)
            self.other_data_len_bits = 0
            if self.other_data_present:
                if audio_mux_version == 1:
                    self.other_data_len_bits = AacAudioRtpPacket.latm_value(reader)
                else:
                    # Escape-coded length: 8-bit chunks, each preceded by a
                    # 1-bit "more chunks follow" flag.
                    while True:
                        other_data_len_esc = reader.read(1)
                        self.other_data_len_bits = (
                            self.other_data_len_bits * 256 + reader.read(8)
                        )
                        if not other_data_len_esc:
                            break
            crc_check_present = reader.read(1)
            if crc_check_present:
                reader.read(8)  # crcCheckSum

    @dataclass
    class AudioMuxElement:
        """AudioMuxElement - ISO/IEC 14496-3 Table 1.41.

        Only muxConfigPresent == 1 with useSameStreamMux == 0 is supported. The
        restrictions of StreamMuxConfig also apply. NOTE(review): numSubFrames is
        not checked; only the first sub-frame's payload is extracted.
        """

        payload: bytes
        stream_mux_config: AacAudioRtpPacket.StreamMuxConfig

        def __init__(self, reader: BitReader, mux_config_present: int):
            if mux_config_present == 0:
                raise core.InvalidPacketError('muxConfigPresent == 0 not supported')

            use_same_stream_mux = reader.read(1)
            if use_same_stream_mux:
                raise core.InvalidPacketError('useSameStreamMux == 1 not supported')
            self.stream_mux_config = AacAudioRtpPacket.StreamMuxConfig(reader)

            # PayloadLengthInfo: the length is a sum of bytes, terminated by the
            # first byte that is not 255.
            mux_slot_length_bytes = 0
            while True:
                tmp = reader.read(8)
                mux_slot_length_bytes += tmp
                if tmp != 255:
                    break

            # PayloadMux
            self.payload = reader.read_bytes(mux_slot_length_bytes)

            if self.stream_mux_config.other_data_present:
                reader.skip(self.stream_mux_config.other_data_len_bits)

            # ByteAlign
            while reader.bit_position % 8:
                reader.read(1)

    def __init__(self, data: bytes) -> None:
        """Parse ``data`` (an RTP payload) as an AudioMuxElement(1)."""
        reader = BitReader(data)
        self.audio_mux_element = self.AudioMuxElement(reader, mux_config_present=1)

    def to_adts(self) -> bytes:
        """Return the AAC payload prefixed with a 7-byte ADTS header (no CRC).

        The header's profile is ``audioObjectType - 1`` for object types 1-4.
        Other types cannot be expressed in ADTS and fall back to AAC LC.

        Raises:
            core.InvalidPacketError: if header + payload exceed the 13-bit ADTS
                frame_length field.
        """
        audio_specific_config = (
            self.audio_mux_element.stream_mux_config.audio_specific_config
        )
        sampling_frequency_index = audio_specific_config.sampling_frequency_index
        channel_configuration = audio_specific_config.channel_configuration
        audio_object_type = audio_specific_config.audio_object_type
        profile = audio_object_type - 1 if 1 <= audio_object_type <= 4 else 1

        payload = self.audio_mux_element.payload
        frame_length = len(payload) + 7  # ADTS frame_length includes the header
        if frame_length > 0x1FFF:
            raise core.InvalidPacketError(
                f'frame of {frame_length} bytes too large for ADTS'
            )
        header = bytes(
            [
                0xFF,
                0xF1,  # 0xF9 (MPEG2)
                (profile << 6)
                | (sampling_frequency_index << 2)
                | (channel_configuration >> 2),
                ((channel_configuration & 0x3) << 6) | (frame_length >> 11),
                (frame_length >> 3) & 0xFF,
                ((frame_length << 5) & 0xFF) | 0x1F,
                0xFC,
            ]
        )
        return header + payload
386