xref: /btstack/test/sbc/sbc_encoder.py (revision 5665ea35c6aeb92912f2adba24f1fcfce0aa3616)
1#!/usr/bin/env python
2import numpy as np
3import wave
4import struct
5import sys
6from sbc import *
7
# Analysis filter-bank history (the FIFO fed by sbc_frame_analysis), kept as
# one row per channel so that stereo channels do not overwrite each other's
# state. 80 entries cover the longest window, L = 10 * 8 subbands; with 4
# subbands only the first 40 entries of a row are used.
X = np.zeros(shape=(2, 80), dtype=np.int16)


def fetch_samples_for_next_sbc_frame(fin, frame):
    """Read the PCM samples for one SBC frame from an open wave reader.

    Reads nr_blocks * nr_subbands audio frames from *fin* and stores them in
    frame.pcm as a flat array, in the interleaved order in which the WAV file
    stores them (L, R, L, R, ... for stereo). If the file ends early, the
    remainder is zero-padded, so frame.pcm always holds
    nr_blocks * nr_subbands * nr_channels samples.

    Assumes 16-bit PCM samples. WAV sample data is little-endian.
    """
    nr_audio_frames = frame.nr_blocks * frame.nr_subbands
    raw_data = fin.readframes(nr_audio_frames)  # raw bytes, 2 per sample

    total_samples = nr_audio_frames * frame.nr_channels
    # Integer division: under Python 3 "/" yields a float, which np.zeros
    # rejects as a length.
    nr_read_samples = len(raw_data) // 2

    # "<" pins little-endian decoding instead of the host's native order.
    fmt = "<%dh" % nr_read_samples
    samples = np.array(struct.unpack(fmt, raw_data[:2 * nr_read_samples]), dtype=np.int16)
    padding = np.zeros(total_samples - nr_read_samples, dtype=np.int16)

    frame.pcm = np.concatenate([samples, padding])


def sbc_frame_analysis(frame, ch, blk, C):
    """Run one step of the SBC analysis filter bank for one channel/block.

    Shifts the M = frame.nr_subbands new samples in frame.EX into channel
    ch's history row X[ch], newest sample at index 0. It then windows the
    first L = 10 * M history entries with the prototype C and matrixes the
    result into M subband samples:

        Z[i] = X[ch][i] * C[i]                                 i = 0..L-1
        Y[i] = sum_{k=0..4} Z[i + k*2M]                        i = 0..2M-1
        S[i] = sum_{k} cos((i + 0.5) * (k - M/2) * pi / M) * Y[k]

    Args:
        frame: reads nr_subbands and EX; writes sb_sample[blk][ch][0:M].
        ch: channel index (row of X).
        blk: block index inside the frame.
        C: prototype window with at least 10 * nr_subbands coefficients.
    """
    M = frame.nr_subbands
    L = 10 * M
    M2 = 2 * M

    # View into the module-level state: in-place updates persist across calls.
    history = X[ch]

    # Shift the FIFO by M and insert the new samples, last input at index 0.
    # The copy avoids relying on overlapping-slice semantics.
    history[M:L] = history[:L - M].copy()
    history[:M] = frame.EX[M - 1::-1]

    Z = history[:L] * np.asarray(C)[:L]

    # Row k of the reshape holds Z[k*2M : (k+1)*2M], so the column sums are
    # the partial sums with stride 2M (16 for 8 subbands, 8 for 4).
    Y = Z.reshape(5, M2).sum(axis=0)

    # Matrixing; the cosine offset is M/2 (4 for 8 subbands, 2 for 4).
    rows = np.arange(M).reshape(-1, 1)
    cols = np.arange(M2)
    W = np.cos((rows + 0.5) * (cols - M / 2.0) * np.pi / M)
    S = W.dot(Y)

    for sb in range(M):
        frame.sb_sample[blk][ch][sb] = S[sb]
58
def sbc_analysis(frame):
    """Split the PCM samples of one frame into subband samples.

    frame.pcm holds nr_blocks * nr_subbands samples per channel, interleaved
    across channels exactly as wave.readframes returns them. For each
    channel and block, that channel's next nr_subbands samples are copied
    into frame.EX and run through sbc_frame_analysis. This fills
    frame.sb_sample, which has shape (nr_blocks, nr_channels, nr_subbands).

    Returns 0 on success, or -1 if nr_subbands is neither 4 nor 8.
    """
    if frame.nr_subbands == 4:
        C = Proto_4_40
    elif frame.nr_subbands == 8:
        C = Proto_8_80
    else:
        return -1

    nr_blocks = frame.nr_blocks
    nr_channels = frame.nr_channels
    nr_subbands = frame.nr_subbands

    frame.sb_sample = np.zeros(shape=(nr_blocks, nr_channels, nr_subbands))
    for ch in range(nr_channels):
        for blk in range(nr_blocks):
            for sb in range(nr_subbands):
                # Interleaved layout: audio frame n keeps channel ch at
                # n * nr_channels + ch (for mono this is simply n).
                index = (blk * nr_subbands + sb) * nr_channels + ch
                frame.EX[sb] = np.int16(frame.pcm[index])
            sbc_frame_analysis(frame, ch, blk, C)
    return 0
76
def sbc_encode(frame):
    """Encode one SBC frame in place: subband analysis, then quantization.

    Returns the negative status from sbc_analysis when analysis fails;
    otherwise returns whatever sbc_quantization returns.
    """
    status = sbc_analysis(frame)
    if status < 0:
        return status
    return sbc_quantization(frame)
82
def calculate_joint_stereo_signal(frame):
    """Choose, per subband, between left/right and mid/side (joint) coding.

    For every subband except the last (the range stops at nr_subbands - 1),
    the left/right samples in frame.sb_sample are converted to
    mid = (L + R) / 2 and side = (L - R) / 2. Scale factors are computed for
    both representations. When the mid/side pair needs a smaller sum of
    scalefactors, frame.join[sb] is set to 1, and frame.sb_sample,
    frame.scale_factor and frame.scalefactor for that subband are
    overwritten with the mid/side values.

    Expects a two-channel frame, frame.join pre-zeroed with nr_subbands
    entries, and frame.scale_factor / frame.scalefactor indexable as
    [channel][subband].
    """

    def scale_factor_of(samples):
        # Same search as in sbc_quantization: start scalefactor at 2 and
        # double it (counting doublings in scale_factor) until it is >= every
        # |sample|.
        scale_factor, scalefactor = 0, 2
        for sample in samples:
            while scalefactor < abs(sample):
                scale_factor += 1
                scalefactor *= 2
        return scale_factor, scalefactor

    blocks = range(frame.nr_blocks)
    for sb in range(frame.nr_subbands - 1):
        left = [frame.sb_sample[blk][0][sb] for blk in blocks]
        right = [frame.sb_sample[blk][1][sb] for blk in blocks]
        # Derived from the float samples that sbc_analysis writes, using true
        # division so that negative side values are preserved.
        mid = [(l + r) / 2.0 for l, r in zip(left, right)]
        side = [(l - r) / 2.0 for l, r in zip(left, right)]

        lr_factors = [scale_factor_of(left), scale_factor_of(right)]
        ms_factors = [scale_factor_of(mid), scale_factor_of(side)]

        # Refresh the left/right factors so the comparison does not depend on
        # the caller's earlier state.
        for ch, (scale_factor, scalefactor) in enumerate(lr_factors):
            frame.scale_factor[ch][sb] = scale_factor
            frame.scalefactor[ch][sb] = scalefactor

        if lr_factors[0][1] + lr_factors[1][1] <= ms_factors[0][1] + ms_factors[1][1]:
            continue  # left/right is at least as cheap: keep it

        frame.join[sb] = 1
        for ch, (scale_factor, scalefactor) in enumerate(ms_factors):
            frame.scale_factor[ch][sb] = scale_factor
            frame.scalefactor[ch][sb] = scalefactor
        for blk in blocks:
            frame.sb_sample[blk][0][sb] = mid[blk]
            frame.sb_sample[blk][1][sb] = side[blk]
112
def calculate_scalefactor(max_subbandsample):
    """Return (exponent, scalefactor) for a peak subband magnitude.

    scalefactor is 1 << (exponent + 1), the smallest such power of two
    (starting from 2) that is strictly greater than max_subbandsample.
    """
    exponent = 0
    scalefactor = 2
    while not scalefactor > max_subbandsample:
        exponent += 1
        scalefactor = 1 << (exponent + 1)
    return (exponent, scalefactor)
121
122
def sbc_quantization(frame):
    """Derive scale factors, bit allocation, header fields and quantized samples.

    The steps run in dependency order:
      1. Scale factor per (channel, subband): scalefactor starts at 2 and is
         doubled (incrementing scale_factor) until it is >= every
         |sb_sample| of that subband.
      2. Joint-stereo decision for JOINT_STEREO frames, which may replace
         scale factors and samples with mid/side values.
      3. sbc_bit_allocation and quantizer levels (2**bits - 1).
      4. syncword and calculate_crc.
      5. Quantization of every subband sample into frame.audio_sample.

    Returns 0.
    """
    for ch in range(frame.nr_channels):
        for sb in range(frame.nr_subbands):
            frame.scale_factor[ch][sb] = 0
            frame.scalefactor[ch][sb] = 2
            for blk in range(frame.nr_blocks):
                while frame.scalefactor[ch][sb] < abs(frame.sb_sample[blk][ch][sb]):
                    frame.scale_factor[ch][sb] += 1
                    frame.scalefactor[ch][sb] *= 2

    # The joint-stereo decision rewrites scale factors and samples, so it
    # must run before bit allocation and the CRC. Otherwise those would be
    # derived from the left/right factors while the samples are quantized
    # with the mid/side ones.
    frame.join = np.zeros(frame.nr_subbands, dtype=np.uint8)
    if frame.channel_mode == JOINT_STEREO:
        calculate_joint_stereo_signal(frame)

    frame.bits = sbc_bit_allocation(frame)

    frame.levels = np.zeros(shape=(frame.nr_channels, frame.nr_subbands), dtype=np.int32)
    for ch in range(frame.nr_channels):
        for sb in range(frame.nr_subbands):
            frame.levels[ch][sb] = (1 << frame.bits[ch][sb]) - 1  # 2**bits - 1

    frame.syncword = 156  # 0x9C
    frame.crc_check = calculate_crc(frame)

    for blk in range(frame.nr_blocks):
        for ch in range(frame.nr_channels):
            for sb in range(frame.nr_subbands):
                if frame.levels[ch][sb] > 0:
                    SB = frame.sb_sample[blk][ch][sb]
                    L = frame.levels[ch][sb]
                    SF = frame.scalefactor[ch][sb]
                    frame.audio_sample[blk][ch][sb] = np.uint16(((SB * L / SF + L) - 1.0) / 2.0)
                else:
                    frame.audio_sample[blk][ch][sb] = 0

    return 0
172
def sbc_write_frame(fout, sbc_encoder_frame):
    """Serialize one encoded frame with frame_to_bitstream and write the bytes to fout."""
    fout.write(bytearray(frame_to_bitstream(sbc_encoder_frame)))
177
if __name__ == "__main__":
    usage = '''
    Usage:      ./sbc_encoder.py input.wav blocks subbands bitpool allocation_method[0-LOUDNESS,1-SNR]
    Example:    ./sbc_encoder.py fanfare.wav 16 4 31 0
    '''

    if len(sys.argv) < 6:
        print(usage)
        sys.exit(1)

    infile = sys.argv[1]
    if not infile.endswith('.wav'):
        print(usage)
        sys.exit(1)
    # Swap only the trailing extension: str.replace would also rewrite a
    # '.wav' appearing earlier in the path (e.g. in a directory name).
    sbcfile = infile[:-len('.wav')] + '-encoded.sbc'

    try:
        nr_blocks = int(sys.argv[2])
        nr_subbands = int(sys.argv[3])
        bitpool = int(sys.argv[4])
        allocation_method = int(sys.argv[5])
    except ValueError:
        print(usage)
        sys.exit(1)

    # sbc_analysis only supports 4 or 8 subbands. A frame with no samples
    # would never advance audio_frame_count in the loop below.
    if nr_blocks <= 0 or nr_subbands not in (4, 8):
        print(usage)
        sys.exit(1)

    try:
        fin = wave.open(infile, 'rb')
        try:
            # fetch_samples_for_next_sbc_frame decodes 2-byte samples only.
            if fin.getsampwidth() != 2:
                print("Only 16-bit PCM WAV files are supported")
                sys.exit(1)

            nr_channels = fin.getnchannels()
            sampling_frequency = fin.getframerate()
            nr_audio_frames = fin.getnframes()

            subband_frame_count = 0
            audio_frame_count = 0
            nr_samples = nr_blocks * nr_subbands  # audio frames consumed per SBC frame
            with open(sbcfile, 'wb') as fout:
                while audio_frame_count < nr_audio_frames:
                    if subband_frame_count % 200 == 0:
                        print("== Frame %d ==" % (subband_frame_count))

                    sbc_encoder_frame = SBCFrame(nr_blocks, nr_subbands, nr_channels, bitpool, sampling_frequency, allocation_method)
                    fetch_samples_for_next_sbc_frame(fin, sbc_encoder_frame)

                    if sbc_encode(sbc_encoder_frame) < 0:
                        print("Failed to encode SBC frame %d" % subband_frame_count)
                        sys.exit(1)
                    sbc_write_frame(fout, sbc_encoder_frame)

                    audio_frame_count += nr_samples
                    subband_frame_count += 1
        finally:
            fin.close()

        print("DONE, WAV file %s encoded into SBC file %s " % (infile, sbcfile))

    except (IOError, wave.Error) as e:
        print(e)
        print(usage)
        sys.exit(1)
232
233
234
235
236
237