1*a58d3d2aSXin Li""" 2*a58d3d2aSXin Li/* Copyright (c) 2022 Amazon 3*a58d3d2aSXin Li Written by Jan Buethe and Jean-Marc Valin */ 4*a58d3d2aSXin Li/* 5*a58d3d2aSXin Li Redistribution and use in source and binary forms, with or without 6*a58d3d2aSXin Li modification, are permitted provided that the following conditions 7*a58d3d2aSXin Li are met: 8*a58d3d2aSXin Li 9*a58d3d2aSXin Li - Redistributions of source code must retain the above copyright 10*a58d3d2aSXin Li notice, this list of conditions and the following disclaimer. 11*a58d3d2aSXin Li 12*a58d3d2aSXin Li - Redistributions in binary form must reproduce the above copyright 13*a58d3d2aSXin Li notice, this list of conditions and the following disclaimer in the 14*a58d3d2aSXin Li documentation and/or other materials provided with the distribution. 15*a58d3d2aSXin Li 16*a58d3d2aSXin Li THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17*a58d3d2aSXin Li ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18*a58d3d2aSXin Li LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19*a58d3d2aSXin Li A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 20*a58d3d2aSXin Li OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21*a58d3d2aSXin Li EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22*a58d3d2aSXin Li PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23*a58d3d2aSXin Li PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24*a58d3d2aSXin Li LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25*a58d3d2aSXin Li NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26*a58d3d2aSXin Li SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27*a58d3d2aSXin Li*/ 28*a58d3d2aSXin Li""" 29*a58d3d2aSXin Liimport os 30*a58d3d2aSXin Liimport subprocess 31*a58d3d2aSXin Liimport argparse 32*a58d3d2aSXin Li 33*a58d3d2aSXin Li 34*a58d3d2aSXin Liimport numpy as np 35*a58d3d2aSXin Lifrom scipy.io import wavfile 36*a58d3d2aSXin Liimport tensorflow as tf 37*a58d3d2aSXin Li 38*a58d3d2aSXin Lifrom rdovae import new_rdovae_model, pvq_quantize, apply_dead_zone, sq_rate_metric 39*a58d3d2aSXin Lifrom fec_packets import write_fec_packets, read_fec_packets 40*a58d3d2aSXin Li 41*a58d3d2aSXin Li 42*a58d3d2aSXin Lidebug = False 43*a58d3d2aSXin Li 44*a58d3d2aSXin Liif debug: 45*a58d3d2aSXin Li args = type('dummy', (object,), 46*a58d3d2aSXin Li { 47*a58d3d2aSXin Li 'input' : 'item1.wav', 48*a58d3d2aSXin Li 'weights' : 'testout/rdovae_alignment_fix_1024_120.h5', 49*a58d3d2aSXin Li 'enc_lambda' : 0.0007, 50*a58d3d2aSXin Li 'output' : "test_0007.fec", 51*a58d3d2aSXin Li 'cond_size' : 1024, 52*a58d3d2aSXin Li 'num_redundancy_frames' : 64, 53*a58d3d2aSXin Li 'extra_delay' : 0, 54*a58d3d2aSXin Li 'dump_data' : './dump_data' 55*a58d3d2aSXin Li })() 56*a58d3d2aSXin Li os.environ['CUDA_VISIBLE_DEVICES']="" 57*a58d3d2aSXin Lielse: 58*a58d3d2aSXin Li parser = argparse.ArgumentParser(description='Encode redundancy for Opus neural FEC. 
    parser.add_argument('input', metavar='<input signal>', help='audio input (.wav, or .raw/.pcm/.sw as int16)')
    parser.add_argument('weights', metavar='<weights>', help='trained model file (.h5)')
#    parser.add_argument('enc_lambda', metavar='<lambda>', type=float, help='lambda for controlling encoder rate')
    parser.add_argument('output', type=str, help='output file (will be extended with .fec)')

    parser.add_argument('--dump-data', type=str, default='./dump_data', help='path to dump_data executable (default ./dump_data)')
    parser.add_argument('--cond-size', metavar='<units>', default=1024, type=int, help='number of units in conditioning network (default 1024)')
    parser.add_argument('--quant-levels', type=int, default=40, help='number of quantization steps (default 40)')
    parser.add_argument('--num-redundancy-frames', default=64, type=int, help='number of redundancy frames (20 ms each) per packet (default 64)')
    parser.add_argument('--extra-delay', default=0, type=int, help='the last features in a packet are calculated from decoder-aligned samples; use this option to add extra delay (in samples at 16 kHz)')
    parser.add_argument('--lossfile', type=str, help='file containing loss trace (0 for frame received, 1 for lost)')

    parser.add_argument('--debug-output', action='store_true', help='if set, differently assembled features are written to disk')

    args = parser.parse_args()

model, encoder, decoder, qembedding = new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=1, nb_quant=args.quant_levels, cond_size=args.cond_size)
model.load_weights(args.weights)

lpc_order = 16

## prepare input signal
# SILK frames are 20 ms and LPCNet subframes are 10 ms (160 samples at 16 kHz)
subframe_size = 160
frame_size = 2 * subframe_size

# 91 samples delay to align with SILK decoded frames
silk_delay = 91

# prepend zeros to have enough history to produce the first packet
zero_history = (args.num_redundancy_frames - 1) * frame_size

# dump_data has a (feature) delay of 10 ms
dump_data_delay = 160

total_delay = silk_delay + zero_history + args.extra_delay - dump_data_delay

# load signal
if args.input.endswith('.raw') or args.input.endswith('.pcm') or args.input.endswith('.sw'):
    signal = np.fromfile(args.input, dtype='int16')
elif args.input.endswith('.wav'):
    fs, signal = wavfile.read(args.input)
else:
    raise ValueError(f'unknown input signal format: {args.input}')

# fill up the last frame with zeros
padded_signal_length = len(signal) + total_delay
tail = padded_signal_length % frame_size
right_padding = (frame_size - tail) % frame_size
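# Worked example, assuming the defaults --num-redundancy-frames 64 and
# --extra-delay 0: zero_history = (64 - 1) * 320 = 20160 samples, so
# total_delay = 91 + 20160 + 0 - 160 = 20091 samples (about 1.26 s at 16 kHz).
# right_padding then rounds the padded signal up to a whole number of 20 ms frames:
assert (padded_signal_length + right_padding) % frame_size == 0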

signal = np.concatenate((np.zeros(total_delay, dtype=np.int16), signal, np.zeros(right_padding, dtype=np.int16)))

padded_signal_file = os.path.splitext(args.input)[0] + '_padded.raw'
signal.tofile(padded_signal_file)

# run dump_data on the padded signal to create the feature file
feature_file = os.path.splitext(args.input)[0] + '_features.f32'
command = f"{args.dump_data} -test {padded_signal_file} {feature_file}"
r = subprocess.run(command, shell=True)
if r.returncode != 0:
    raise RuntimeError(f"command '{command}' failed with exit code {r.returncode}")

# feature layout: nb_used_features entries plus lpc_order LPC coefficients per 10 ms subframe
nb_features = model.nb_used_features + lpc_order
nb_used_features = model.nb_used_features

# load features
features = np.fromfile(feature_file, dtype='float32')
num_subframes = len(features) // nb_features
num_subframes = 2 * (num_subframes // 2)
num_frames = num_subframes // 2

features = np.reshape(features, (1, -1, nb_features))
features = features[:, :, :nb_used_features]
features = features[:, :num_subframes, :]

# variable quantizer depending on the delay: quant_id ramps linearly from q1
# towards q0 across the positions within a packet
q0 = 3
q1 = 15
quant_id = np.round(q1 + (q0 - q1) * np.arange(args.num_redundancy_frames // 2) / args.num_redundancy_frames).astype('int16')
#print(quant_id)

quant_embed = qembedding(quant_id)

# run encoder
print("running fec encoder...")
symbols, gru_state_dec = encoder.predict(features)

# apply quantization
nsymbols = 80
quant_scale = tf.math.softplus(quant_embed[:, :nsymbols]).numpy()
dead_zone = tf.math.softplus(quant_embed[:, nsymbols : 2 * nsymbols]).numpy()
#symbols = apply_dead_zone([symbols, dead_zone]).numpy()
#qsymbols = np.round(symbols)
quant_gru_state_dec = pvq_quantize(gru_state_dec, 82)

# rate estimate
hard_distr_embed = tf.math.sigmoid(quant_embed[:, 4 * nsymbols:]).numpy()
#rate_input = np.concatenate((qsymbols, hard_distr_embed, enc_lambda), axis=-1)
#rates = sq_rate_metric(None, rate_input, reduce=False).numpy()

# run decoder
input_length = args.num_redundancy_frames // 2
offset = args.num_redundancy_frames - 1

packets = []
packet_sizes = []

sym_batch = np.zeros((num_frames - offset, args.num_redundancy_frames // 2, nsymbols), dtype='float32')
quant_state = quant_gru_state_dec[0, offset:num_frames, :]
# pack symbols for batch processing
for i in range(offset, num_frames):
    sym_batch[i - offset, :, :] = symbols[0, i - 2 * input_length + 2 : i + 1 : 2, :]
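# Illustration, assuming the default --num-redundancy-frames 64 (input_length = 32,
# offset = 63) and, as the shared frame index i with gru_state_dec suggests, one
# symbol vector per 20 ms frame: for the first batch entry i = 63, the slice
# symbols[0, 1:64:2, :] picks the symbol vectors at frames 1, 3, ..., 63, i.e. one
# vector every other frame (40 ms apart), ending at the current frame. Each
# following batch entry slides this window forward by one frame.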

# quantize symbols
sym_batch = sym_batch * quant_scale
sym_batch = apply_dead_zone([sym_batch, dead_zone]).numpy()
sym_batch = np.round(sym_batch)

# rate estimate on the quantized symbols
hard_distr_embed = np.broadcast_to(hard_distr_embed, (sym_batch.shape[0], sym_batch.shape[1], 2 * sym_batch.shape[2]))
fake_lambda = np.ones((sym_batch.shape[0], sym_batch.shape[1], 1), dtype='float32')
rate_input = np.concatenate((sym_batch, hard_distr_embed, fake_lambda), axis=-1)
rates = sq_rate_metric(None, rate_input, reduce=False).numpy()
#print(rates.shape)
print("average rate = ", np.mean(rates[args.num_redundancy_frames:, :]))

#sym_batch.tofile('qsyms.f32')

# undo the quantization scaling and run the decoder on all packets at once
sym_batch = sym_batch / quant_scale
#print(sym_batch.shape, quant_state.shape)
#features = decoder.predict([sym_batch, quant_state])
features = decoder([sym_batch, quant_state])

#for i in range(offset, num_frames):
#    print(f"processing frame {i - offset}...")
#    features = decoder.predict([qsymbols[:, i - 2 * input_length + 2 : i + 1 : 2, :], quant_embed_dec[:, i - 2 * input_length + 2 : i + 1 : 2, :], quant_gru_state_dec[:, i, :]])
#    packets.append(features)
#    packet_size = 8 * int((np.sum(rates[:, i - 2 * input_length + 2 : i + 1 : 2]) + 7) / 8) + 64
#    packet_sizes.append(packet_size)


# write packets
packet_file = args.output + '.fec' if not args.output.endswith('.fec') else args.output
#write_fec_packets(packet_file, packets, packet_sizes)

#print(f"average redundancy rate: {int(round(sum(packet_sizes) / len(packet_sizes) * 50 / 1000))} kbps")

if args.lossfile is not None:
    loss = np.loadtxt(args.lossfile, dtype='int16')
    fec_out = np.zeros((features.shape[0] * 2, features.shape[-1]), dtype='float32')
    foffset = -2
    ptr = 0
    count = 2
    # walk the loss trace: whenever a frame arrives (loss[i] == 0) or the trace
    # ends, recover the features for the current frame and all frames lost since
    # the previous received frame from the newest packet's redundancy
    for i in range(features.shape[0]):
        if (loss[i] == 0) or (i == features.shape[0] - 1):
            fec_out[ptr:ptr + count, :] = features[i, foffset:, :]
            #print("filled ", count)
            foffset = -2
            ptr = ptr + count
            count = 2
        else:
            count = count + 2
            foffset = foffset - 2

    fec_out_full = np.zeros((fec_out.shape[0], nb_features), dtype=np.float32)
    fec_out_full[:, :nb_used_features] = fec_out

    fec_out_full.tofile(packet_file[:-4] + '_fec.f32')


# create a packets array like in the original version for debugging purposes
for i in range(offset, num_frames):
    packets.append(features[i - offset:i - offset + 1, :, :])
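
# Each entry of `packets` holds the decoder output for one batch entry, i.e. the
# redundancy window ending at frame k + args.num_redundancy_frames - 1. The sanity
# checks below take `batch` consecutive output steps ending `offset` steps before
# the end of every (batch//2)-th packet, concatenate them, and write the result to
# disk so that differently assembled feature streams can be compared.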

if args.debug_output:
    import itertools

    #batches = [2, 4]
    batches = [4]
    #offsets = [0, 4, 20]
    offsets = [0, (args.num_redundancy_frames - 2) * 2]

    # sanity checks
    # 1. concatenate features at offset 0
    for batch, offset in itertools.product(batches, offsets):
        stop = packets[0].shape[1] - offset
        print(batch, offset, stop)
        test_features = np.concatenate([packet[:, stop - batch : stop, :] for packet in packets[::batch // 2]], axis=1)

        test_features_full = np.zeros((test_features.shape[1], nb_features), dtype=np.float32)
        test_features_full[:, :nb_used_features] = test_features[0, :, :]

        outfile = packet_file[:-4] + f'_tf_batch{batch}_offset{offset}.f32'
        print(f"writing debug output {outfile}")
        test_features_full.tofile(outfile)
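
# Example invocation (file names are placeholders; this assumes a dump_data
# feature-extraction binary has been built and sits at the default ./dump_data path):
#
#   python3 fec_encoder.py input.wav rdovae_model.h5 output.fec \
#       --num-redundancy-frames 64 --lossfile loss_trace.txt --debug-output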