#
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import math


class Bitstream:
    """Byte buffer shared by two independently advancing bit streams.

    * The forward side is arithmetic coded. Its byte pointer ``bp`` grows
      from index 0 upward, and the coder state is ``low`` / ``range``.
    * The backward side carries raw bits. It starts at the last byte and
      moves toward index 0, filling each byte from its least significant
      bit. ``bp_bw`` is the byte index and ``mask_bw`` the current bit mask.

    NOTE(review): the layout looks like the LC3 codec bitstream; confirm
    against the specification before relying on that.
    """

    def __init__(self, data):

        self.bytes = data

        # Backward side: last byte, least significant bit first.
        self.bp_bw = len(data) - 1
        self.mask_bw = 1

        # Forward (arithmetic coder) side.
        self.bp = 0
        self.low = 0
        self.range = 0xffffff

    def dump(self):
        """Print the buffer as hexadecimal, 20 bytes per line."""

        b = self.bytes

        for i in range(0, len(b), 20):
            print(''.join('{:02x} '.format(x) for x in b[i:i + 20]))

    def _check_bw(self):
        """Raise IndexError if the backward side has run past byte 0.

        Without this check, ``bp_bw == -1`` would silently index the *last*
        byte of the buffer (Python negative indexing). That would alias data
        at the opposite end instead of failing like a forward overrun does.
        """

        if self.bp_bw < 0:
            raise IndexError('backward bitstream overrun')

    def _advance_bw(self):
        """Move the backward side to the next bit, stepping down a byte
        after the most significant bit."""

        self.mask_bw <<= 1
        if self.mask_bw == 0x100:
            self.mask_bw = 1
            self.bp_bw -= 1

    def _nbits_bw(self):
        """Return the number of bits consumed so far on the backward side."""

        nbits = 8 * len(self.bytes)
        return nbits - (8 * self.bp_bw + 8 - int(math.log2(self.mask_bw)))


class BitstreamReader(Bitstream):
    """Read raw bits (backward side) and arithmetic-coded symbols (forward
    side) from ``data``."""

    def __init__(self, data):
        """Wrap ``data`` and preload the first three bytes into ``low``.

        ``data`` must be indexable and hold at least 3 bytes; otherwise
        IndexError is raised.
        """

        super().__init__(data)

        self.low = ((self.bytes[0] << 16) |
                    (self.bytes[1] << 8) |
                    (self.bytes[2]))
        self.bp = 3

    def read_bit(self):
        """Read one raw bit from the backward side.

        Returns:
            bool: the bit value.

        Raises:
            IndexError: if the backward side has already consumed byte 0.
        """

        self._check_bw()
        bit = bool(self.bytes[self.bp_bw] & self.mask_bw)
        self._advance_bw()

        return bit

    def read_uint(self, nbits):
        """Read an ``nbits``-wide unsigned integer, least significant bit
        first, from the backward side."""

        val = 0
        for k in range(nbits):
            val |= self.read_bit() << k

        return val

    def ac_decode(self, cum_freqs, sym_freqs):
        """Decode one symbol from the forward (arithmetic-coded) side.

        Frequencies are in units of ``range >> 10``, i.e. 1/1024 of the
        current range. Presumably ``cum_freqs`` is non-decreasing with
        ``cum_freqs[0] == 0``; verify against the callers' tables.

        Args:
            cum_freqs: cumulative frequency (lower bound) of each symbol.
            sym_freqs: frequency (width) of each symbol.

        Returns:
            int: index of the decoded symbol.

        Raises:
            ValueError: if ``low`` lies outside the coded interval
                ('Invalid ac bitstream'), or if the selected symbol has a
                non-positive frequency.
            IndexError: if renormalisation reads past the end of the buffer.
        """

        r = self.range >> 10
        if self.low >= r << 10:
            raise ValueError('Invalid ac bitstream')

        # Pick the largest symbol whose lower bound does not exceed ``low``.
        val = len(cum_freqs) - 1
        while self.low < r * cum_freqs[val]:
            val -= 1

        # A non-positive width would make the renormalisation loop below
        # spin forever (the range could never reach 0x10000).
        if sym_freqs[val] <= 0:
            raise ValueError('Invalid ac model: non-positive symbol frequency')

        self.low -= r * cum_freqs[val]
        self.range = r * sym_freqs[val]
        while self.range < 0x10000:
            self.range <<= 8

            # Shift in the next byte of the forward side.
            self.low <<= 8
            self.low &= 0xffffff
            self.low += self.bytes[self.bp]
            self.bp += 1

        return val

    def get_bits_left(self):
        """Return the number of bits not yet consumed by either side.

        The result may be negative when the two sides have together
        consumed more bits than the buffer holds.
        """

        nbits = 8 * len(self.bytes)

        # Subtract the 3 bytes preloaded by __init__, so a fresh reader
        # starts from the same count as a fresh writer.
        nbits_ac = 8 * (self.bp - 3) + \
            (25 - int(math.floor(math.log2(self.range))))

        return nbits - (self._nbits_bw() + nbits_ac)


class BitstreamWriter(Bitstream):
    """Fill a zeroed buffer of ``nbytes`` from both ends.

    Arithmetic-coded symbols go forward from index 0; raw bits go backward
    from the last byte.
    """

    def __init__(self, nbytes):

        super().__init__(bytearray(nbytes))

        # Arithmetic coder output state:
        # - ``cache``: the last top byte not yet written (-1 means none yet).
        # - ``carry``: a pending +1 to add to ``cache``.
        # - ``carry_count``: deferred bytes after ``cache``. Each is written
        #   as 0xff, or 0x00 if a carry arrives first.
        self.cache = -1
        self.carry = 0
        self.carry_count = 0

    def write_bit(self, bit):
        """Write one raw bit on the backward side. Any ``bit`` that is not
        equal to 0 is written as 1.

        Raises:
            IndexError: if the backward side has already filled byte 0.
        """

        self._check_bw()

        mask = self.mask_bw
        bp = self.bp_bw

        if bit == 0:
            self.bytes[bp] &= ~mask
        else:
            self.bytes[bp] |= mask

        self._advance_bw()

    def write_uint(self, val, nbits):
        """Write the low ``nbits`` of ``val``, least significant bit first,
        on the backward side."""

        for k in range(nbits):
            self.write_bit(val & 1)
            val >>= 1

    def ac_shift(self):
        """Shift the top byte out of ``low``, writing it or deferring it.

        A top byte of 0xff with no pending carry could still be changed by
        a later carry. It is therefore only counted in ``carry_count``.
        Otherwise the previous ``cache`` byte (plus carry) is written,
        followed by the deferred bytes, and the new top byte becomes the
        cache.
        """

        if self.low < 0xff0000 or self.carry == 1:

            if self.cache >= 0:
                self.bytes[self.bp] = self.cache + self.carry
                self.bp += 1

            while self.carry_count > 0:
                self.bytes[self.bp] = (self.carry + 0xff) & 0xff
                self.bp += 1
                self.carry_count -= 1

            self.cache = self.low >> 16
            self.carry = 0

        else:
            self.carry_count += 1

        self.low <<= 8
        self.low &= 0xffffff

    def ac_encode(self, cum_freq, sym_freq):
        """Encode one symbol on the forward (arithmetic-coded) side.

        Args:
            cum_freq: cumulative frequency (lower bound) of the symbol, in
                units of ``range >> 10``.
            sym_freq: frequency (width) of the symbol, same units.

        Raises:
            ValueError: if ``sym_freq`` is not positive. Such a value would
                make the range non-positive, and the renormalisation loop
                would never terminate.
        """

        if sym_freq <= 0:
            raise ValueError('Invalid ac symbol frequency')

        r = self.range >> 10
        self.low += r * cum_freq
        if (self.low >> 24) != 0:
            # Overflow past 24 bits propagates into the cached byte.
            self.carry = 1

        self.low &= 0xffffff
        self.range = r * sym_freq
        while self.range < 0x10000:
            self.range <<= 8
            self.ac_shift()

    def get_bits_left(self):
        """Return the number of bits not yet used by either side.

        Bytes held back in ``cache`` and ``carry_count`` are counted as used.
        """

        nbits = 8 * len(self.bytes)

        nbits_ac = 8 * self.bp + (25 - int(math.floor(math.log2(self.range))))
        if self.cache >= 0:
            nbits_ac += 8
        nbits_ac += 8 * self.carry_count

        return nbits - (self._nbits_bw() + nbits_ac)

    def terminate(self):
        """Flush the arithmetic coder and return the buffer.

        It picks the value with the fewest significant bits inside the final
        interval, flushes the pending bytes, then writes the remaining
        1..8 most significant bits into the next forward byte.

        Returns:
            bytearray: ``self.bytes`` (the same object, not a copy).
        """

        # Smallest number of leading bits that still resolves the interval.
        bits = 1
        while self.range >> (24 - bits) == 0:
            bits += 1

        mask = 0xffffff >> bits
        val = self.low + mask

        over1 = val >> 24
        val &= 0x00ffffff
        high = self.low + self.range
        over2 = high >> 24
        high &= 0x00ffffff
        val = val & ~mask

        if over1 == over2:

            # Rounded value would reach the top of the interval: use one
            # more bit of precision.
            if val + mask >= high:
                bits += 1
                mask >>= 1
                val = ((self.low + mask) & 0x00ffffff) & ~mask

            # Wrap-around: the chosen value carries into the cached byte.
            if val < self.low:
                self.carry = 1

        self.low = val
        while bits > 0:
            self.ac_shift()
            bits -= 8
        bits += 8

        val = self.cache

        if self.carry_count > 0:
            self.bytes[self.bp] = self.cache
            self.bp += 1

            while self.carry_count > 1:
                self.bytes[self.bp] = 0xff
                self.bp += 1
                self.carry_count -= 1

            val = 0xff >> (8 - bits)

        # Write the remaining ``bits`` most significant bits of ``val``.
        mask = 0x80
        for k in range(bits):

            if val & mask == 0:
                self.bytes[self.bp] &= ~mask
            else:
                self.bytes[self.bp] |= mask

            mask >>= 1

        return self.bytes