xref: /aosp_15_r20/external/liblc3/test/bitstream.py (revision 49fe348c0058011ee60b6957cdd9d52742df84bc)
1*49fe348cSAndroid Build Coastguard Worker#
2*49fe348cSAndroid Build Coastguard Worker# Copyright 2022 Google LLC
3*49fe348cSAndroid Build Coastguard Worker#
4*49fe348cSAndroid Build Coastguard Worker# Licensed under the Apache License, Version 2.0 (the "License");
5*49fe348cSAndroid Build Coastguard Worker# you may not use this file except in compliance with the License.
6*49fe348cSAndroid Build Coastguard Worker# You may obtain a copy of the License at
7*49fe348cSAndroid Build Coastguard Worker#
8*49fe348cSAndroid Build Coastguard Worker#     http://www.apache.org/licenses/LICENSE-2.0
9*49fe348cSAndroid Build Coastguard Worker#
10*49fe348cSAndroid Build Coastguard Worker# Unless required by applicable law or agreed to in writing, software
11*49fe348cSAndroid Build Coastguard Worker# distributed under the License is distributed on an "AS IS" BASIS,
12*49fe348cSAndroid Build Coastguard Worker# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*49fe348cSAndroid Build Coastguard Worker# See the License for the specific language governing permissions and
14*49fe348cSAndroid Build Coastguard Worker# limitations under the License.
15*49fe348cSAndroid Build Coastguard Worker#
16*49fe348cSAndroid Build Coastguard Worker
17*49fe348cSAndroid Build Coastguard Workerimport math
18*49fe348cSAndroid Build Coastguard Worker
class Bitstream:
    """State shared by the bitstream reader and writer.

    The buffer ``bytes`` is accessed from both ends:

    * a bit-wise cursor (``bp_bw`` / ``mask_bw``) that starts at the least
      significant bit of the last byte, and
    * a byte-wise cursor ``bp`` that starts at the front, together with the
      coder state ``low`` and ``range``.
    """

    def __init__(self, data):
        self.bytes = data

        # Byte-wise (front) cursor and coder state.
        self.bp, self.low, self.range = 0, 0, 0xffffff

        # Bit-wise (back) cursor: last byte, least significant bit.
        self.bp_bw, self.mask_bw = len(data) - 1, 1

    def dump(self):
        """Print the buffer in hex, 20 bytes per line, each followed by a space."""
        data = self.bytes
        per_line = 20
        for start in range(0, len(data), per_line):
            row = data[start:start + per_line]
            print(''.join(f'{byte:02x} ' for byte in row))
39*49fe348cSAndroid Build Coastguard Worker
class BitstreamReader(Bitstream):
    """Reader for a buffer holding two streams that grow toward each other.

    Plain bits are read backward from the end of the buffer (least
    significant bit first within each byte) by ``read_bit`` / ``read_uint``.
    Arithmetic-coded symbols are decoded forward from the start by
    ``ac_decode``.
    """

    def __init__(self, data):
        """Wrap ``data`` (indexable bytes; at least 3 bytes long)."""
        super().__init__(data)

        # Prime the 24-bit decoder window ``low`` with the first three bytes.
        self.low = ((self.bytes[0] << 16) |
                    (self.bytes[1] << 8) |
                    self.bytes[2])
        self.bp = 3

    def read_bit(self):
        """Read the next bit of the backward stream and return it as a bool.

        Raises:
            IndexError: if the backward cursor has already moved past the
                first byte of the buffer.
        """
        # A negative index would silently wrap around to the end of the
        # buffer in Python and return unrelated bits; treat it as an overrun.
        if self.bp_bw < 0:
            raise IndexError('backward bitstream read past start of buffer')

        bit = bool(self.bytes[self.bp_bw] & self.mask_bw)

        # Advance LSB -> MSB, then step to the previous byte.
        self.mask_bw <<= 1
        if self.mask_bw == 0x100:
            self.mask_bw = 1
            self.bp_bw -= 1

        return bit

    def read_uint(self, nbits):
        """Read an ``nbits``-bit unsigned integer, least significant bit first."""
        val = 0
        for k in range(nbits):
            val |= self.read_bit() << k

        return val

    def ac_decode(self, cum_freqs, sym_freqs):
        """Decode the next arithmetic-coded symbol and return its index.

        Args:
            cum_freqs: cumulative frequency per symbol. Assumed ascending
                with ``cum_freqs[0] == 0``, so the search below terminates
                at index 0.
            sym_freqs: frequency per symbol.

        NOTE(review): frequencies presumably sum to 1024, since the coder
        works on ``range >> 10``. Confirm against the callers' tables.

        Raises:
            ValueError: if ``low`` lies outside the coded range.
            IndexError: if renormalization reads past the end of the buffer.
        """
        r = self.range >> 10
        if self.low >= r << 10:
            raise ValueError('Invalid ac bitstream')

        # Pick the largest symbol whose scaled cumulative frequency is <= low.
        val = len(cum_freqs) - 1
        while self.low < r * cum_freqs[val]:
            val -= 1

        self.low -= r * cum_freqs[val]
        self.range = r * sym_freqs[val]

        # Renormalize: while range < 2^16, shift the next input byte into
        # the 24-bit ``low`` window.
        while self.range < 0x10000:
            self.range <<= 8
            self.low = ((self.low << 8) & 0xffffff) + self.bytes[self.bp]
            self.bp += 1

        return val

    def get_bits_left(self):
        """Return the number of bits not yet consumed by either stream."""
        nbytes = len(self.bytes)
        nbits = 8 * nbytes

        # Backward stream: whole bytes already passed, plus the bits already
        # read from the current byte. ``bit_length() - 1`` is an exact
        # integer log2 of the power-of-two ``mask_bw``.
        nbits_bw = (8 * (nbytes - 1 - self.bp_bw) +
                    (self.mask_bw.bit_length() - 1))

        # Forward stream: bytes fetched beyond the 3 priming bytes, plus a
        # term derived from the bit length of ``range``. This equals
        # 25 - floor(log2(range)), computed exactly with integers.
        nbits_ac = 8 * (self.bp - 3) + 26 - self.range.bit_length()

        return nbits - (nbits_bw + nbits_ac)
103*49fe348cSAndroid Build Coastguard Worker
class BitstreamWriter(Bitstream):
    """Writer filling a ``nbytes``-byte buffer from both ends.

    Plain bits are written backward from the end of the buffer (least
    significant bit first within each byte) by ``write_bit`` /
    ``write_uint``. Arithmetic-coded symbols are written forward from the
    start by ``ac_encode``. Call ``terminate`` once at the end to flush the
    arithmetic coder and obtain the buffer.
    """

    def __init__(self, nbytes):
        super().__init__(bytearray(nbytes))

        # Arithmetic-coder output state:
        #   cache       -- last top byte of ``low`` not yet emitted (-1: none);
        #   carry       -- 1 when ``low`` overflowed 24 bits; added into
        #                  ``cache`` when it is emitted;
        #   carry_count -- number of deferred 0xff top bytes, emitted as
        #                  0xff (or 0x00 when a carry ripples through them).
        self.cache = -1
        self.carry = 0
        self.carry_count = 0

    def write_bit(self, bit):
        """Write one bit to the backward stream (``bit == 0`` clears, else sets).

        Raises:
            IndexError: if the backward cursor has already moved past the
                first byte of the buffer.
        """
        # A negative index would silently wrap around to the end of the
        # buffer in Python, overwriting bits already written; fail instead.
        if self.bp_bw < 0:
            raise IndexError('backward bitstream write past start of buffer')

        mask = self.mask_bw
        bp = self.bp_bw

        if bit == 0:
            self.bytes[bp] &= ~mask
        else:
            self.bytes[bp] |= mask

        # Advance LSB -> MSB, then step to the previous byte.
        self.mask_bw <<= 1
        if self.mask_bw == 0x100:
            self.mask_bw = 1
            self.bp_bw -= 1

    def write_uint(self, val, nbits):
        """Write the low ``nbits`` bits of ``val``, least significant bit first."""
        for _ in range(nbits):
            self.write_bit(val & 1)
            val >>= 1

    def ac_shift(self):
        """Shift the top byte out of ``low``, emitting bytes once they are final.

        A top byte of 0xff with no pending carry is only counted in
        ``carry_count``: a later carry would turn it into 0x00 and propagate
        into the cached byte.
        """
        if self.low < 0xff0000 or self.carry == 1:

            # Emit the cached byte, including any carry.
            if self.cache >= 0:
                self.bytes[self.bp] = self.cache + self.carry
                self.bp += 1

            # Flush the deferred run: 0xff bytes, or 0x00 after a carry.
            pending = 0x00 if self.carry else 0xff
            while self.carry_count > 0:
                self.bytes[self.bp] = pending
                self.bp += 1
                self.carry_count -= 1

            self.cache = self.low >> 16
            self.carry = 0

        else:
            self.carry_count += 1

        self.low = (self.low << 8) & 0xffffff

    def ac_encode(self, cum_freq, sym_freq):
        """Encode one symbol given its cumulative and individual frequency.

        NOTE(review): frequencies presumably sum to 1024, since the coder
        works on ``range >> 10``. Confirm against the callers' tables.
        """
        r = self.range >> 10
        self.low += r * cum_freq
        if self.low >> 24:
            self.carry = 1

        self.low &= 0xffffff
        self.range = r * sym_freq

        # Renormalize: keep range >= 2^16, shifting out one byte per step.
        while self.range < 0x10000:
            self.range <<= 8
            self.ac_shift()

    def get_bits_left(self):
        """Return the number of bits still free between the two streams."""
        nbytes = len(self.bytes)
        nbits = 8 * nbytes

        # Backward stream: whole bytes already passed, plus the bits already
        # written into the current byte (exact integer log2 of ``mask_bw``).
        nbits_bw = (8 * (nbytes - 1 - self.bp_bw) +
                    (self.mask_bw.bit_length() - 1))

        # Forward stream: emitted bytes, plus the cached byte and deferred
        # 0xff run not yet written, plus a term derived from the bit length
        # of ``range``. 26 - bit_length == 25 - floor(log2(range)), exactly.
        pending = (self.cache >= 0) + self.carry_count
        nbits_ac = 8 * (self.bp + pending) + 26 - self.range.bit_length()

        return nbits - (nbits_bw + nbits_ac)

    def terminate(self):
        """Flush the arithmetic coder and return the buffer.

        Picks a final value ``val >= low`` whose bits below the top ``bits``
        bits are zero. If ``val + mask`` would reach ``low + range``, one
        more bit is used. Enough bytes are then shifted out, and the top
        ``bits`` bits of the final partial byte are written at ``self.bp``.
        The remaining low-order bits of that byte are left untouched.
        """
        # Smallest bit count such that range >> (24 - bits) is non-zero.
        bits = 1
        while self.range >> (24 - bits) == 0:
            bits += 1

        mask = 0xffffff >> bits
        val = self.low + mask

        over1 = val >> 24
        val &= 0x00ffffff
        high = self.low + self.range
        over2 = high >> 24
        high &= 0x00ffffff
        val = val & ~mask

        if over1 == over2:

            if val + mask >= high:
                bits += 1
                mask >>= 1
                val = ((self.low + mask) & 0x00ffffff) & ~mask

            # Rounding ``low`` up wrapped past 24 bits: propagate a carry.
            if val < self.low:
                self.carry = 1

        self.low = val

        # Shift out ceil(bits / 8) bytes. Afterwards ``bits`` (1..8) is the
        # number of bits of the final partial byte still to be written.
        while bits > 0:
            self.ac_shift()
            bits -= 8
        bits += 8

        val = self.cache

        # With a deferred 0xff run, emit the cached byte and all but one
        # deferred byte. The partial byte is then taken from
        # 0xff >> (8 - bits).
        if self.carry_count > 0:
            self.bytes[self.bp] = self.cache
            self.bp += 1

            while self.carry_count > 1:
                self.bytes[self.bp] = 0xff
                self.bp += 1
                self.carry_count -= 1

            val = 0xff >> (8 - bits)

        # Write the top ``bits`` bits of ``val`` (MSB first) into bytes[bp].
        mask = 0x80
        for _ in range(bits):

            if (val & mask) == 0:
                self.bytes[self.bp] &= ~mask
            else:
                self.bytes[self.bp] |= mask

            mask >>= 1

        return self.bytes
240*49fe348cSAndroid Build Coastguard Worker        return self.bytes
241