xref: /btstack/3rd-party/lc3-google/test/bitstream.py (revision 4930cef6e21e6da2d7571b9259c7f0fb8bed3d01)
#
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
16*4930cef6SMatthias Ringwald
17*4930cef6SMatthias Ringwaldimport math
18*4930cef6SMatthias Ringwald
class Bitstream:
    """Cursor state shared by the reader and writer over one byte buffer.

    Two independent positions are kept on the same buffer:

    * ``bp_bw`` / ``mask_bw`` -- byte index and bit mask of the plain-bit
      cursor, which starts on the last byte at its least-significant bit.
    * ``bp`` / ``low`` / ``range`` -- byte index and 24-bit state of the
      arithmetic coder, which starts on the first byte.
    """

    def __init__(self, data):
        """Attach ``data`` and reset both cursors.

        Args:
            data: Indexable sequence of byte values supporting ``len()``
                (for example ``bytes`` or ``bytearray``).
        """
        self.bytes = data

        # Plain-bit cursor: last byte, bit 0.
        self.bp_bw = len(data) - 1
        self.mask_bw = 1

        # Arithmetic-coder cursor and initial interval.
        self.bp = 0
        self.low = 0
        self.range = 0xffffff

    def dump(self):
        """Print the buffer as two-digit hex values, 20 bytes per row.

        Every value, including the last one of a row, is followed by a
        single space.
        """
        buf = self.bytes

        for start in range(0, len(buf), 20):
            # Slicing already clamps at the end of the buffer, so the final
            # (possibly shorter) row needs no explicit bound.
            print(''.join('{:02x} '.format(x) for x in buf[start:start + 20]))
39*4930cef6SMatthias Ringwald
class BitstreamReader(Bitstream):
    """Reads plain bits from the end of a buffer and arithmetic-coded
    symbols from its start.
    """

    def __init__(self, data):
        """Attach ``data`` and prime the arithmetic decoder.

        Args:
            data: Indexable sequence of byte values; indices 0, 1 and 2 are
                read immediately, so at least three bytes are required.
        """
        super().__init__(data)

        # The decoder window is the first three bytes, most significant
        # byte first; the next byte to consume is therefore index 3.
        self.low = ((self.bytes[0] << 16) |
                    (self.bytes[1] << 8) |
                    (self.bytes[2]))
        self.bp = 3

    def read_bit(self):
        """Return the bit under the plain-bit cursor and advance it.

        Bits are taken from least- to most-significant within a byte, and
        the byte index moves towards the start of the buffer.

        Returns:
            ``True`` if the bit is set, ``False`` otherwise.
        """
        bit = bool(self.bytes[self.bp_bw] & self.mask_bw)

        self.mask_bw <<= 1
        if self.mask_bw == 0x100:
            # Current byte exhausted: continue with the preceding byte.
            self.mask_bw = 1
            self.bp_bw -= 1

        return bit

    def read_uint(self, nbits):
        """Read ``nbits`` plain bits as an unsigned integer.

        The first bit read becomes bit 0 of the result.

        Args:
            nbits: Number of bits to read; 0 yields 0.

        Returns:
            The assembled non-negative integer.
        """
        val = 0
        for k in range(nbits):
            val |= self.read_bit() << k

        return val

    def ac_decode(self, cum_freqs, sym_freqs):
        """Decode one arithmetic-coded symbol.

        Args:
            cum_freqs: Per-symbol cumulative frequencies, indexable by
                symbol. The search below stops at the highest index whose
                scaled cumulative frequency does not exceed ``low``.
            sym_freqs: Per-symbol frequencies, indexed like ``cum_freqs``.

        Returns:
            Index of the decoded symbol.

        Raises:
            ValueError: If ``low`` is not below ``(range >> 10) << 10``.
        """
        r = self.range >> 10
        if self.low >= r << 10:
            raise ValueError('Invalid ac bitstream')

        # Scan downwards from the last symbol.
        val = len(cum_freqs) - 1
        while self.low < r * cum_freqs[val]:
            val -= 1

        self.low -= r * cum_freqs[val]
        self.range = r * sym_freqs[val]

        # Renormalise: pull in one byte per step until range has at least
        # 17 significant bits again.
        while self.range < 0x10000:
            self.range <<= 8

            self.low <<= 8
            self.low &= 0xffffff
            self.low += self.bytes[self.bp]
            self.bp += 1

        return val

    def get_bits_left(self):
        """Return the number of buffer bits not yet accounted for.

        The total is the buffer size in bits, minus the plain bits already
        read from the end, minus the arithmetic-coder consumption computed
        as ``8 * (bp - 3) + 25 - floor(log2(range))``.
        """
        nbits = 8 * len(self.bytes)

        # `x.bit_length() - 1` is floor(log2(x)) for positive integers,
        # evaluated exactly in integer arithmetic.
        nbits_bw = nbits - \
            (8 * self.bp_bw + 8 - (self.mask_bw.bit_length() - 1))

        nbits_ac = 8 * (self.bp - 3) + \
            (25 - (self.range.bit_length() - 1))

        return nbits - (nbits_bw + nbits_ac)
103*4930cef6SMatthias Ringwald
class BitstreamWriter(Bitstream):
    """Writes plain bits from the end of a buffer and arithmetic-coded
    symbols from its start.
    """

    def __init__(self, nbytes):
        """Allocate a zero-filled buffer of ``nbytes`` bytes.

        Args:
            nbytes: Size of the output buffer in bytes.
        """
        super().__init__(bytearray(nbytes))

        # Most recent coder byte not yet stored; -1 means "none yet".
        self.cache = -1
        # Set to 1 when an addition to `low` overflowed 24 bits.
        self.carry = 0
        # Number of deferred bytes counted by ac_shift while `low` was
        # >= 0xff0000 with no carry pending.
        self.carry_count = 0

    def write_bit(self, bit):
        """Store one plain bit under the plain-bit cursor and advance it.

        Bits fill a byte from least- to most-significant, and the byte
        index moves towards the start of the buffer.

        Args:
            bit: 0 clears the target bit; any other value sets it.
        """
        mask = self.mask_bw
        bp = self.bp_bw

        if bit == 0:
            self.bytes[bp] &= ~mask
        else:
            self.bytes[bp] |= mask

        self.mask_bw <<= 1
        if self.mask_bw == 0x100:
            # Current byte full: continue with the preceding byte.
            self.mask_bw = 1
            self.bp_bw -= 1

    def write_uint(self, val, nbits):
        """Write the ``nbits`` low-order bits of ``val``, bit 0 first.

        Args:
            val: Non-negative integer to write.
            nbits: Number of bits to emit.
        """
        for _ in range(nbits):
            self.write_bit(val & 1)
            val >>= 1

    def ac_shift(self):
        """Move the top byte of ``low`` out of the coder state.

        When ``low`` is below 0xff0000 or a carry is pending, the cached
        byte (plus carry) and every deferred byte are stored, and the top
        byte of ``low`` becomes the new cache. Otherwise the byte is only
        counted in ``carry_count`` and stored by a later call.
        """
        if self.low < 0xff0000 or self.carry == 1:

            if self.cache >= 0:
                self.bytes[self.bp] = self.cache + self.carry
                self.bp += 1

            # Deferred bytes become 0xff without carry, 0x00 with carry.
            while self.carry_count > 0:
                self.bytes[self.bp] = (self.carry + 0xff) & 0xff
                self.bp += 1
                self.carry_count -= 1

            self.cache = self.low >> 16
            self.carry = 0

        else:
            self.carry_count += 1

        self.low <<= 8
        self.low &= 0xffffff

    def ac_encode(self, cum_freq, sym_freq):
        """Encode one symbol with the arithmetic coder.

        Args:
            cum_freq: Cumulative frequency of the symbol.
            sym_freq: Frequency of the symbol.
        """
        r = self.range >> 10
        self.low += r * cum_freq
        if (self.low >> 24) != 0:
            # Overflow past 24 bits: remember it for the next ac_shift.
            self.carry = 1

        self.low &= 0xffffff
        self.range = r * sym_freq

        # Renormalise: shift one byte out per step until range has at
        # least 17 significant bits again.
        while self.range < 0x10000:
            self.range <<= 8
            self.ac_shift()

    def get_bits_left(self):
        """Return the number of buffer bits not yet accounted for.

        The total is the buffer size in bits, minus the plain bits already
        written from the end, minus the arithmetic-coder usage: 8 bits per
        stored byte, 8 for a pending cache byte, 8 per deferred byte, plus
        ``25 - floor(log2(range))``.
        """
        nbits = 8 * len(self.bytes)

        # `x.bit_length() - 1` is floor(log2(x)) for positive integers,
        # evaluated exactly in integer arithmetic.
        nbits_bw = nbits - \
            (8 * self.bp_bw + 8 - (self.mask_bw.bit_length() - 1))

        nbits_ac = 8 * self.bp + (25 - (self.range.bit_length() - 1))
        if self.cache >= 0:
            nbits_ac += 8
        if self.carry_count > 0:
            nbits_ac += 8 * self.carry_count

        return nbits - (nbits_bw + nbits_ac)

    def terminate(self):
        """Flush the arithmetic coder and return the buffer.

        Returns:
            The underlying ``bytearray``.
        """
        # Smallest `bits` for which range has a set bit at or above
        # position (24 - bits).
        bits = 1
        while self.range >> (24 - bits) == 0:
            bits += 1

        mask = 0xffffff >> bits
        val = self.low + mask

        # over1 / over2 record whether the candidate value and the upper
        # bound low + range each overflowed 24 bits.
        over1 = val >> 24
        val &= 0x00ffffff
        high = self.low + self.range
        over2 = high >> 24
        high &= 0x00ffffff
        val = val & ~mask

        if over1 == over2:

            if val + mask >= high:
                # Retry with one more bit of precision.
                bits += 1
                mask >>= 1
                val = ((self.low + mask) & 0x00ffffff) & ~mask

            if val < self.low:
                self.carry = 1

        self.low = val
        while bits > 0:
            self.ac_shift()
            bits -= 8
        # The loop leaves `bits` in -7..0; adding 8 gives the number of
        # bits to write into the final byte.
        bits += 8

        val = self.cache

        if self.carry_count > 0:
            self.bytes[self.bp] = self.cache
            self.bp += 1

            # All deferred bytes but the last are stored as 0xff.
            while self.carry_count > 1:
                self.bytes[self.bp] = 0xff
                self.bp += 1
                self.carry_count -= 1

            val = 0xff >> (8 - bits)

        # Copy the top `bits` bits of `val` into the byte at `bp`, most
        # significant bit first, leaving its remaining bits untouched.
        mask = 0x80
        for _ in range(bits):

            if (val & mask) == 0:
                self.bytes[self.bp] &= ~mask
            else:
                self.bytes[self.bp] |= mask

            mask >>= 1

        return self.bytes
241