1# Copyright 2023 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import Union
16
17# -----------------------------------------------------------------------------
18# Constants
19# -----------------------------------------------------------------------------
20# fmt: off
21
# Log scale factor increments added to the lower sub-band ``nb`` in LOGSCL
# (Block 3L); indexed through RL42.
WL = [-60, -30, 58, 172, 334, 538, 1198, 3042]
# Maps the 4-bit lower sub-band code (6-bit code >> 2) to an index into WL.
RL42 = [0, 7, 6, 5, 4, 3, 2, 1, 7, 6, 5, 4, 3, 2, 1, 0]
# Inverse-log table used by SCALEL and SCALEH: indexed by bits 6..10 of ``nb``
# and then shifted to produce the linear quantizer scale factor ``det``.
ILB = [
    2048,
    2093,
    2139,
    2186,
    2233,
    2282,
    2332,
    2383,
    2435,
    2489,
    2543,
    2599,
    2656,
    2714,
    2774,
    2834,
    2896,
    2960,
    3025,
    3091,
    3158,
    3228,
    3298,
    3371,
    3444,
    3520,
    3597,
    3676,
    3756,
    3838,
    3922,
    4008,
]
# Log scale factor increments added to the higher sub-band ``nb`` in LOGSCH
# (Block 3H); indexed through RH2.
WH = [0, -214, 798]
# Maps the 2-bit higher sub-band code to an index into WH.
RH2 = [2, 1, 2, 1]
# Values in QM2/QM4/QM6 are left-shifted by three bits relative to the tables
# in the original G.722 specification.
# QM2: inverse quantizer for the 2-bit higher sub-band code (INVQAH, Block 2H).
QM2 = [-7408, -1616, 7408, 1616]
# QM4: inverse quantizer for the 4-bit lower sub-band code (INVQAL, Block 2L);
# its output feeds the lower-band predictor update.
QM4 = [
    0,
    -20456,
    -12896,
    -8968,
    -6288,
    -4240,
    -2584,
    -1200,
    20456,
    12896,
    8968,
    6288,
    4240,
    2584,
    1200,
    0,
]
# QM6: inverse quantizer for the full 6-bit lower sub-band code (INVQBL,
# Block 5L); its output is used to reconstruct the lower-band signal.
QM6 = [
    -136,
    -136,
    -136,
    -136,
    -24808,
    -21904,
    -19008,
    -16704,
    -14984,
    -13512,
    -12280,
    -11192,
    -10232,
    -9360,
    -8576,
    -7856,
    -7192,
    -6576,
    -6000,
    -5456,
    -4944,
    -4464,
    -4008,
    -3576,
    -3168,
    -2776,
    -2400,
    -2032,
    -1688,
    -1360,
    -1040,
    -728,
    24808,
    21904,
    19008,
    16704,
    14984,
    13512,
    12280,
    11192,
    10232,
    9360,
    8576,
    7856,
    7192,
    6576,
    6000,
    5456,
    4944,
    4464,
    4008,
    3576,
    3168,
    2776,
    2400,
    2032,
    1688,
    1360,
    1040,
    728,
    432,
    136,
    -432,
    -136,
]
# Receive QMF filter coefficients (12 taps); applied to the even delay-line
# entries in order and to the odd entries in reverse order.
QMF_COEFFS = [3, -11, 12, 32, -210, 951, 3876, -805, 362, -156, 53, -11]
147
148# fmt: on
149
150
151# -----------------------------------------------------------------------------
152# Classes
153# -----------------------------------------------------------------------------
class G722Decoder:
    """G.722 decoder with bitrate 64kbit/s.

    Each input byte is one G.722 codeword: the upper two bits are the higher
    sub-band code and the lower six bits are the lower sub-band code. Every
    codeword decodes to two signed 16-bit little-endian PCM samples. Decoder
    state persists across calls, so consecutive frames of one stream must be
    fed to the same instance.

    For the Blocks in the sub-band decoders, please refer to the G.722
    specification for the required information. G722 specification:
    https://www.itu.int/rec/T-REC-G.722-201209-I
    """

    def __init__(self) -> None:
        # Receive QMF delay line: 12 interleaved (sum, difference) pairs of
        # the sub-band reconstructions; the newest pair sits at 22/23.
        self._x = [0] * 24
        # Per-band predictor state: index 0 is lower band, index 1 higher.
        self._band = [Band(), Band()]
        # The initial value in BLOCK 3L
        self._band[0].det = 32
        # The initial value in BLOCK 3H
        self._band[1].det = 8

    def decode_frame(self, encoded_data: Union[bytes, bytearray]) -> bytearray:
        """Decode a frame of G.722 codewords into 16-bit PCM.

        Args:
            encoded_data: G.722 codewords, one per byte.

        Returns:
            Signed 16-bit little-endian PCM: four bytes (two samples) for
            every input byte.
        """
        # Two samples x two bytes per codeword.
        result_array = bytearray(len(encoded_data) * 4)
        self.g722_decode(result_array, encoded_data)
        return result_array

    def g722_decode(
        self, result_array: bytearray, encoded_data: Union[bytes, bytearray]
    ) -> int:
        """Decode the data frame using g722 decoder.

        Args:
            result_array: Output buffer, written from offset 0. It must hold
                at least ``4 * len(encoded_data)`` bytes, otherwise an
                ``IndexError`` is raised part-way through.
            encoded_data: G.722 codewords, one per byte.

        Returns:
            The number of bytes written to ``result_array``.
        """
        result_length = 0

        for code in encoded_data:
            # Split the codeword: top 2 bits -> higher band, low 6 -> lower.
            higher_bits = (code >> 6) & 0x03
            lower_bits = code & 0x3F

            rlow = self.lower_sub_band_decoder(lower_bits)
            rhigh = self.higher_sub_band_decoder(higher_bits)

            # Apply the receive QMF: drop the oldest pair from the delay
            # line and append the sum/difference of the new sub-band values.
            self._x[:22] = self._x[2:]
            self._x[22] = rlow + rhigh
            self._x[23] = rlow - rhigh

            # Even taps use the coefficients in order, odd taps in reverse.
            xout2 = sum(self._x[2 * i] * QMF_COEFFS[i] for i in range(12))
            xout1 = sum(self._x[2 * i + 1] * QMF_COEFFS[11 - i] for i in range(12))

            result_length = self.update_decoded_result(
                xout1, result_length, result_array
            )
            result_length = self.update_decoded_result(
                xout2, result_length, result_array
            )

        return result_length

    def update_decoded_result(
        self, xout: int, byte_length: int, byte_array: bytearray
    ) -> int:
        """Store one QMF output as a signed 16-bit little-endian sample.

        Args:
            xout: QMF accumulator value; the sample is ``xout >> 11``,
                saturated to [-32768, 32767].
            byte_length: Offset in ``byte_array`` where the sample goes.
            byte_array: Output buffer with room for two more bytes.

        Returns:
            The offset just past the stored sample (``byte_length + 2``).
        """
        # Saturate instead of letting int.to_bytes raise OverflowError: the
        # QMF has a DC gain of 4096, so near full-scale or corrupted input
        # can push ``xout >> 11`` outside the 16-bit range.
        result = min(max(xout >> 11, -32768), 32767)
        bytes_result = result.to_bytes(2, 'little', signed=True)
        byte_array[byte_length] = bytes_result[0]
        byte_array[byte_length + 1] = bytes_result[1]
        return byte_length + 2

    def lower_sub_band_decoder(self, lower_bits: int) -> int:
        """Lower sub-band decoder for last six bits.

        Args:
            lower_bits: 6-bit lower sub-band code (0-63).

        Returns:
            The reconstructed lower sub-band value, limited to
            [-16384, 16383]. Also advances the lower band's adaptation state.
        """
        band = self._band[0]

        # Block 5L
        # INVQBL: inverse-quantize the full 6-bit code.
        wd1 = lower_bits
        wd2 = QM6[wd1]
        # The top 4 bits alone drive the adaptation steps below.
        wd1 >>= 2
        wd2 = (band.det * wd2) >> 15
        # RECONS
        rlow = band.s + wd2

        # Block 6L
        # LIMIT
        if rlow > 16383:
            rlow = 16383
        elif rlow < -16384:
            rlow = -16384

        # Block 2L
        # INVQAL: 4-bit inverse quantizer; its output updates the predictor.
        wd2 = QM4[wd1]
        dlowt = (band.det * wd2) >> 15

        # Block 3L
        # LOGSCL: leaky update of the logarithmic scale factor.
        wd2 = RL42[wd1]
        wd1 = (band.nb * 127) >> 7
        wd1 += WL[wd2]

        if wd1 < 0:
            wd1 = 0
        elif wd1 > 18432:
            wd1 = 18432

        band.nb = wd1

        # SCALEL: convert the log scale factor to the linear ``det``.
        wd1 = (band.nb >> 6) & 31
        wd2 = 8 - (band.nb >> 11)

        if wd2 < 0:
            wd3 = ILB[wd1] << -wd2
        else:
            wd3 = ILB[wd1] >> wd2

        band.det = wd3 << 2

        # Block 4L
        band.block4(dlowt)

        return rlow

    def higher_sub_band_decoder(self, higher_bits: int) -> int:
        """Higher sub-band decoder for first two bits.

        Args:
            higher_bits: 2-bit higher sub-band code (0-3).

        Returns:
            The reconstructed higher sub-band value, limited to
            [-16384, 16383]. Also advances the higher band's adaptation state.
        """
        band = self._band[1]

        # Block 2H
        # INVQAH
        wd2 = QM2[higher_bits]
        dhigh = (band.det * wd2) >> 15

        # Block 5H
        # RECONS
        rhigh = dhigh + band.s

        # Block 6H
        # LIMIT
        if rhigh > 16383:
            rhigh = 16383
        elif rhigh < -16384:
            rhigh = -16384

        # Block 3H
        # LOGSCH: leaky update of the logarithmic scale factor.
        wd2 = RH2[higher_bits]
        wd1 = (band.nb * 127) >> 7
        wd1 += WH[wd2]

        if wd1 < 0:
            wd1 = 0
        elif wd1 > 22528:
            wd1 = 22528
        band.nb = wd1

        # SCALEH: convert the log scale factor to the linear ``det``.
        wd1 = (band.nb >> 6) & 31
        wd2 = 10 - (band.nb >> 11)

        if wd2 < 0:
            wd3 = ILB[wd1] << -wd2
        else:
            wd3 = ILB[wd1] >> wd2
        band.det = wd3 << 2

        # Block 4H
        band.block4(dhigh)

        return rhigh
310
311
312# -----------------------------------------------------------------------------
class Band:
    """Adaptive predictor state for one G.722 sub-band.

    ``s`` is the signal estimate produced by :meth:`block4`; ``nb`` and
    ``det`` are scale factors that the owning decoder reads and writes
    directly. The underscore-prefixed fields are the predictor history.
    """

    s: int = 0
    nb: int = 0
    det: int = 0

    def __init__(self) -> None:
        # Pole section (two taps): predictor output, reconstructed and
        # partially reconstructed signal history, coefficients and updates.
        self._sp = 0
        self._r = [0, 0, 0]
        self._p = [0, 0, 0]
        self._a = [0, 0, 0]
        self._ap = [0, 0, 0]
        # Zero section (six taps): predictor output, difference signal
        # history, coefficients and updates.
        self._sz = 0
        self._d = [0] * 7
        self._b = [0] * 7
        self._bp = [0] * 7
        # Sign-bit scratch space shared by the coefficient update steps.
        self._sg = [0] * 7

    def saturate(self, amp: int) -> int:
        """Clamp ``amp`` into the signed 16-bit range [-32768, 32767]."""
        return min(max(amp, -32768), 32767)

    def block4(self, d: int) -> None:
        """Run the shared Block 4 predictor update for one difference value.

        Feeds the quantized difference ``d`` through the adaptive pole/zero
        predictor, advances the delay lines and coefficients, and leaves the
        next signal estimate in ``self.s``.
        """
        clamp = self.saturate
        dl, r, p, sg = self._d, self._r, self._p, self._sg
        a, ap, b, bp = self._a, self._ap, self._b, self._bp

        # RECONS / PARREC: newest reconstructed and partial signals.
        dl[0] = d
        r[0] = clamp(self.s + d)
        p[0] = clamp(self._sz + d)

        # UPPOL2: adapt the second pole coefficient.
        sg[0], sg[1], sg[2] = (p[0] >> 15, p[1] >> 15, p[2] >> 15)
        a1_term = clamp(a[1] << 2)
        if sg[0] == sg[1]:
            a1_term = -a1_term
        # Negating -32768 yields 32768, which must be pulled back in range.
        a1_term = min(a1_term, 32767)
        a2_next = (
            (128 if sg[0] == sg[2] else -128)
            + (a1_term >> 7)
            + ((a[2] * 32512) >> 15)
        )
        ap[2] = min(max(a2_next, -12288), 12288)

        # UPPOL1: adapt the first pole coefficient. sg[0] and sg[1] still
        # hold the sign bits of p[0] and p[1] from the step above.
        a1_next = clamp(
            (192 if sg[0] == sg[1] else -192) + ((a[1] * 32640) >> 15)
        )
        a1_bound = clamp(15360 - ap[2])
        if a1_next > a1_bound:
            a1_next = a1_bound
        elif a1_next < -a1_bound:
            a1_next = -a1_bound
        ap[1] = a1_next

        # UPZERO: adapt the six zero coefficients.
        gain = 128 if d else 0
        sg[0] = d >> 15
        for k in range(1, 7):
            sg[k] = dl[k] >> 15
            step = gain if sg[k] == sg[0] else -gain
            bp[k] = clamp(step + ((b[k] * 32640) >> 15))

        # DELAYA: age the histories by one and commit the new coefficients.
        dl[1:] = dl[:-1]
        b[1:] = bp[1:]
        r[1:] = r[:-1]
        p[1:] = p[:-1]
        a[1:] = ap[1:]

        # FILTEP / FILTEZ: pole and zero predictor outputs.
        self._sp = clamp(sum((a[k] * clamp(2 * r[k])) >> 15 for k in (1, 2)))
        self._sz = clamp(
            sum((b[k] * clamp(2 * dl[k])) >> 15 for k in range(1, 7))
        )

        # PREDIC: next signal estimate.
        self.s = clamp(self._sp + self._sz)
421