"""
/* Copyright (c) 2022 Amazon
   Written by Jan Buethe and Jean-Marc Valin */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import os
import subprocess
import argparse

os.environ['CUDA_VISIBLE_DEVICES'] = ""

parser = argparse.ArgumentParser(description='Encode redundancy for Opus neural FEC. Designed for use with the voip application and 20 ms frames.')

parser.add_argument('input', metavar='<input signal>', help='audio input (.wav, .raw or .pcm as int16)')
parser.add_argument('checkpoint', metavar='<weights>', help='model checkpoint')
parser.add_argument('q0', metavar='<quant level 0>', type=int, help='quantization level for the most recent frame')
parser.add_argument('q1', metavar='<quant level 1>', type=int, help='quantization level for the oldest frame')
parser.add_argument('output', type=str, help='output file (will be extended with .fec)')

parser.add_argument('--dump-data', type=str, default='./dump_data', help='path to dump_data executable (default ./dump_data)')
parser.add_argument('--num-redundancy-frames', default=52, type=int, help='number of redundancy frames per packet (default 52)')
parser.add_argument('--extra-delay', default=0, type=int, help='by default, the last features in a packet are aligned with the decoded samples; use this option to add extra delay (in samples at 16 kHz)')
parser.add_argument('--lossfile', type=str, help='file containing loss trace (0 for frame received, 1 for lost)')
parser.add_argument('--debug-output', action='store_true', help='if set, differently assembled features are written to disk')

args = parser.parse_args()

import numpy as np
from scipy.io import wavfile
import torch

from rdovae import RDOVAE
from packets import write_fec_packets

torch.set_num_threads(4)

checkpoint = torch.load(args.checkpoint, map_location="cpu")
model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
model.load_state_dict(checkpoint['state_dict'], strict=False)
model.to("cpu")

lpc_order = 16

## prepare input signal
# SILK frame size is 20 ms and LPCNet subframes are 10 ms
subframe_size = 160
frame_size = 2 * subframe_size

# 91-sample delay to align with SILK decoded frames
silk_delay = 91
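
# Worked example of the timing constants above (an illustrative sanity check,
# assuming the 16 kHz sample rate used throughout this script): one 10 ms
# LPCNet subframe is 160 samples, so one 20 ms SILK frame is 320 samples.
assert subframe_size == 160 and frame_size == 320, 'frame layout assumes 16 kHz'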

# prepend zeros to have enough history to produce the first packet
zero_history = (args.num_redundancy_frames - 1) * frame_size

# dump_data has a (feature) delay of 10 ms
dump_data_delay = 160

total_delay = silk_delay + zero_history + args.extra_delay - dump_data_delay

# load signal
if args.input.endswith('.raw') or args.input.endswith('.pcm'):
    signal = np.fromfile(args.input, dtype='int16')
elif args.input.endswith('.wav'):
    fs, signal = wavfile.read(args.input)
else:
    raise ValueError(f'unknown input signal format: {args.input}')

# prepend the delay and fill up the last frame with zeros
padded_signal_length = len(signal) + total_delay
tail = padded_signal_length % frame_size
right_padding = (frame_size - tail) % frame_size

signal = np.concatenate((np.zeros(total_delay, dtype=np.int16), signal, np.zeros(right_padding, dtype=np.int16)))

padded_signal_file = os.path.splitext(args.input)[0] + '_padded.raw'
signal.tofile(padded_signal_file)

# call dump_data on the padded signal to create features
feature_file = os.path.splitext(args.input)[0] + '_features.f32'
command = [args.dump_data, '-test', padded_signal_file, feature_file]
r = subprocess.run(command)
if r.returncode != 0:
    raise RuntimeError(f"command '{' '.join(command)}' failed with exit code {r.returncode}")

# feature dimensions
nb_features = model.feature_dim + lpc_order
nb_used_features = model.feature_dim

# load features
features = np.fromfile(feature_file, dtype='float32')
num_subframes = len(features) // nb_features
num_subframes = 2 * (num_subframes // 2)
num_frames = num_subframes // 2

features = np.reshape(features, (1, -1, nb_features))
features = features[:, :, :nb_used_features]
features = features[:, :num_subframes, :]

# quant_ids in reverse decoding order
quant_ids = torch.round((args.q1 + (args.q0 - args.q1) * torch.arange(args.num_redundancy_frames // 2) / (args.num_redundancy_frames // 2 - 1))).long()

print(f"using quantization levels {quant_ids}...")

# convert input to torch tensors
features = torch.from_numpy(features)

# run encoder
print("running fec encoder...")
with torch.no_grad():

    # encoding
    z, states, state_size = model.encode(features)

    # run decoder on packet chunks
    input_length = args.num_redundancy_frames // 2
    offset = args.num_redundancy_frames - 1

    packets = []
    packet_sizes = []

    for i in range(offset, num_frames):
        print(f"processing frame {i - offset}...")
        # quantize / unquantize latent vectors
        zi = torch.clone(z[:, i - 2 * input_length + 2 : i + 1 : 2, :])
        zi, rates = model.quantize(zi, quant_ids)
        zi = model.unquantize(zi, quant_ids)

        decoded_features = model.decode(zi, states[:, i : i + 1, :])
        packets.append(decoded_features.squeeze(0).numpy())
        # round the rate estimate plus state size up to whole bytes (in bits)
        packet_size = 8 * int((torch.sum(rates) + 7 + state_size) / 8)
        packet_sizes.append(packet_size)

# write packets
packet_file = args.output + '.fec' if not args.output.endswith('.fec') else args.output
write_fec_packets(packet_file, packets, packet_sizes)

print(f"average redundancy rate: {int(round(sum(packet_sizes) / len(packet_sizes) * 50 / 1000))} kbps")
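
# The rate printed above follows from the packet timing: one redundancy packet
# is produced per 20 ms frame, i.e. 50 packets per second, and packet_sizes
# holds byte-aligned sizes in bits. A spelled-out version of the same
# computation (an illustrative sketch, equivalent to the print above):
bits_per_packet = sum(packet_sizes) / len(packet_sizes)
redundancy_kbps = bits_per_packet * 50 / 1000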

# assemble features according to loss file (0 = received, 1 = lost)
if args.lossfile is not None:
    num_packets = len(packets)
    loss = np.loadtxt(args.lossfile, dtype='int16')
    fec_out = np.zeros((num_packets * 2, packets[0].shape[-1]), dtype='float32')
    foffset = -2
    ptr = 0
    count = 2
    for i in range(num_packets):
        if (loss[i] == 0) or (i == num_packets - 1):
            fec_out[ptr : ptr + count, :] = packets[i][foffset:, :]

            ptr += count
            foffset = -2
            count = 2
        else:
            count += 2
            foffset -= 2

    fec_out_full = np.zeros((fec_out.shape[0], nb_features), dtype=np.float32)
    fec_out_full[:, : fec_out.shape[-1]] = fec_out

    fec_out_full.tofile(packet_file[:-4] + '_fec.f32')


if args.debug_output:
    import itertools

    batches = [4]
    offsets = [0, 2 * args.num_redundancy_frames - 4]

    # sanity checks
    # 1. concatenate features at offset 0
    for batch, offset in itertools.product(batches, offsets):
        stop = packets[0].shape[1] - offset
        test_features = np.concatenate([packet[stop - batch : stop, :] for packet in packets[::batch // 2]], axis=0)

        test_features_full = np.zeros((test_features.shape[0], nb_features), dtype=np.float32)
        test_features_full[:, :nb_used_features] = test_features[:, :]

        print(f"writing debug output {packet_file[:-4] + f'_torch_batch{batch}_offset{offset}.f32'}")
        test_features_full.tofile(packet_file[:-4] + f'_torch_batch{batch}_offset{offset}.f32')
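
# Example invocation (illustrative; the file names below are hypothetical and
# assume a trained RDOVAE checkpoint plus the dump_data binary built from
# LPCNet):
#
#   python fec_encoder.py speech.wav checkpoint.pth 0 15 speech.fec \
#       --num-redundancy-frames 52 --lossfile loss_trace.txt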