xref: /aosp_15_r20/external/libopus/dnn/torch/rdovae/fec_encoder.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1"""
2/* Copyright (c) 2022 Amazon
3   Written by Jan Buethe and Jean-Marc Valin */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
30import os
31import subprocess
32import argparse
33
# Hide CUDA devices; this is set before torch is imported further down.
os.environ['CUDA_VISIBLE_DEVICES'] = ""

parser = argparse.ArgumentParser(description='Encode redundancy for Opus neural FEC. Designed for use with voip application and 20ms frames')

parser.add_argument('input', metavar='<input signal>', help='audio input (.wav or .raw or .pcm as int16)')
parser.add_argument('checkpoint', metavar='<weights>', help='model checkpoint')
parser.add_argument('q0', metavar='<quant level 0>', type=int, help='quantization level for most recent frame')
parser.add_argument('q1', metavar='<quant level 1>', type=int, help='quantization level for oldest frame')
parser.add_argument('output', type=str, help='output file (will be extended with .fec)')

parser.add_argument('--dump-data', type=str, default='./dump_data', help='path to dump data executable (default ./dump_data)')
parser.add_argument('--num-redundancy-frames', default=52, type=int, help='number of redundancy frames per packet (default 52)')
parser.add_argument('--extra-delay', default=0, type=int, help="last features in packet are calculated with the decoder aligned samples, use this option to add extra delay (in samples at 16kHz)")
parser.add_argument('--lossfile', type=str, help='file containing loss trace (0 for frame received, 1 for lost)')
parser.add_argument('--debug-output', action='store_true', help='if set, differently assembled features are written to disk')

args = parser.parse_args()

# The per-packet quantization levels are interpolated over
# num_redundancy_frames // 2 latent vectors, dividing by
# (num_redundancy_frames // 2 - 1). Fewer than 4 frames would divide by zero
# (yielding NaN levels), so reject such values up front.
if args.num_redundancy_frames < 4:
    parser.error('--num-redundancy-frames must be at least 4')
51
52import numpy as np
53from scipy.io import wavfile
54import torch
55
56from rdovae import RDOVAE
57from packets import write_fec_packets
58
torch.set_num_threads(4)

# Rebuild the RDOVAE from the constructor arguments stored alongside the
# weights, then load the weights; everything is kept on the CPU.
# NOTE(review): torch.load unpickles the checkpoint file — only load trusted
# checkpoints.
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model_args = checkpoint['model_args']
model_kwargs = checkpoint['model_kwargs']
model = RDOVAE(*model_args, **model_kwargs)
# strict=False: missing or unexpected state-dict keys are tolerated.
model.load_state_dict(checkpoint['state_dict'], strict=False)
model.to("cpu")
65
# Extra columns per dump_data feature row that the model does not consume
# (see nb_features below).
lpc_order = 16

## prepare input signal
# Geometry in samples: LPCNet subframes are 10 ms, SILK frames are 20 ms.
subframe_size = 160
frame_size = subframe_size * 2

# Delays, all in samples.
silk_delay = 91         # aligns the features with SILK decoded frames
dump_data_delay = 160   # dump_data's own 10 ms feature delay
# leading zeros so that even the first packet has a full history behind it
zero_history = frame_size * (args.num_redundancy_frames - 1)

# Net number of zero samples prepended to the input signal.
total_delay = (silk_delay - dump_data_delay) + zero_history + args.extra_delay
83
# load signal
# The padded signal is later written back as raw int16 samples for dump_data,
# and all delays in this script are in samples at 16 kHz, so WAV input must
# already be 16 kHz mono int16. Other formats would otherwise be silently
# upcast and written with the wrong sample width.
if args.input.endswith(('.raw', '.pcm')):
    signal = np.fromfile(args.input, dtype='int16')

elif args.input.endswith('.wav'):
    fs, signal = wavfile.read(args.input)
    if fs != 16000:
        raise ValueError(f'expected a 16 kHz signal, got {fs} Hz: {args.input}')
    if signal.ndim != 1:
        raise ValueError(f'expected a mono signal, got samples of shape {signal.shape}: {args.input}')
    if signal.dtype != np.int16:
        raise ValueError(f'expected int16 samples, got {signal.dtype}: {args.input}')
else:
    raise ValueError(f'unknown input signal format: {args.input}')
92
# Prepend total_delay zeros, append just enough zeros to complete the last
# 20 ms frame, and store the result as a raw file next to the input.
padded_signal_length = len(signal) + total_delay
tail = padded_signal_length % frame_size
right_padding = 0 if tail == 0 else frame_size - tail

left_zeros = np.zeros(total_delay, dtype=np.int16)
right_zeros = np.zeros(right_padding, dtype=np.int16)
signal = np.concatenate((left_zeros, signal, right_zeros))

input_stem = os.path.splitext(args.input)[0]
padded_signal_file = input_stem + '_padded.raw'
signal.tofile(padded_signal_file)
102
# write signal and call dump_data to create features
feature_file = os.path.splitext(args.input)[0] + '_features.f32'
# --dump-data is a path to an executable, so run it directly with an argument
# list (no shell): paths containing spaces or shell metacharacters are then
# passed through unchanged instead of being re-split by the shell.
command = [args.dump_data, '-test', padded_signal_file, feature_file]
command_str = ' '.join(command)
try:
    r = subprocess.run(command)
except OSError as err:
    # e.g. the dump_data executable does not exist or is not executable
    raise RuntimeError(f"command '{command_str}' could not be started: {err}") from err
if r.returncode != 0:
    raise RuntimeError(f"command '{command_str}' failed with exit code {r.returncode}")
110
# Each row of the dump_data feature file holds the model's feature_dim values
# followed by lpc_order extra values that the model does not use.
nb_used_features = model.feature_dim
nb_features = nb_used_features + lpc_order

# load features, keeping only complete 20 ms frames (pairs of 10 ms subframes)
raw_features = np.fromfile(feature_file, dtype='float32')
num_frames = len(raw_features) // nb_features // 2
num_subframes = 2 * num_frames

# (1, subframes, nb_features) -> drop any unpaired trailing subframe and the
# unused trailing columns
features = raw_features.reshape((1, -1, nb_features))[:, :num_subframes, :nb_used_features]
124
# One quantization level per latent vector in a packet, in reverse decoding
# order: linearly interpolated from q1 (first entry, oldest frame) to q0
# (last entry, most recent frame) and rounded to integer levels.
num_quant_steps = args.num_redundancy_frames // 2
level_offsets = (args.q0 - args.q1) * torch.arange(num_quant_steps) / (num_quant_steps - 1)
quant_ids = torch.round(args.q1 + level_offsets).long()

print(f"using quantization levels {quant_ids}...")

# the encoder consumes torch tensors
features = torch.from_numpy(features)
132
133
# run encoder
# Encode the whole feature sequence once, then build one FEC packet per frame
# i: take every second latent vector ending at frame i, quantize and
# unquantize it with quant_ids, and decode it together with the state at
# frame i.
print("running fec encoder...")
with torch.no_grad():

    # encoding
    # NOTE(review): z and states appear to be (batch, time, dim) given the
    # three-index slicing below — confirm against RDOVAE.encode.
    z, states, state_size = model.encode(features)


    # decoder on packet chunks
    # number of latent vectors per packet (one per entry of quant_ids)
    input_length = args.num_redundancy_frames // 2
    # First frame that gets a packet; the zero_history padding prepended to
    # the signal supplies the history before it.
    offset = args.num_redundancy_frames - 1

    packets = []       # decoded features, one numpy array per packet
    packet_sizes = []  # size of each packet in bits

    for i in range(offset, num_frames):
        print(f"processing frame {i - offset}...")
        # quantize / unquantize latent vectors
        # Latents i - 2*input_length + 2, ..., i - 2, i (oldest first);
        # presumably quant_ids[k] is applied to zi[:, k], i.e. q1 to the
        # oldest and q0 to the most recent latent — confirm in RDOVAE.quantize.
        zi = torch.clone(z[:, i - 2 * input_length + 2: i + 1 : 2, :])
        zi, rates = model.quantize(zi, quant_ids)
        zi = model.unquantize(zi, quant_ids)

        # NOTE: this rebinds the module-level `features` (the encoder input)
        # to the decoded packet; the input is not read again afterwards.
        features = model.decode(zi, states[:, i : i + 1, :])
        packets.append(features.squeeze(0).numpy())
        # Latent rates plus state size, rounded up to whole bytes and
        # expressed in bits. NOTE(review): int() truncates, so this is an
        # exact byte round-up only when the sum is integral — confirm the
        # units of `rates` / `state_size`.
        packet_size = 8 * int((torch.sum(rates) + 7 + state_size) / 8)
        packet_sizes.append(packet_size)
160
161
# write packets
# Append the .fec extension unless the user already supplied it.
packet_file = args.output + '.fec' if not args.output.endswith('.fec') else args.output

# The rate report, loss-trace assembly and debug output all need at least one
# packet; fail clearly instead of with ZeroDivisionError / IndexError later.
if not packets:
    raise RuntimeError(f'no FEC packets were produced for {args.input}; the input signal is probably too short')
write_fec_packets(packet_file, packets, packet_sizes)


# packet_sizes are in bits and one packet is produced per 20 ms frame
# (50 packets per second): mean bits per packet * 50 / 1000 = kbps.
print(f"average redundancy rate: {int(round(sum(packet_sizes) / len(packet_sizes) * 50 / 1000))} kbps")
168
# assemble features according to loss file
# Emulates the receiver: each received packet (and always the last one)
# contributes its newest `count` subframes, where `count` covers the two
# subframes of that frame plus two for every lost packet since the last
# received one.
if args.lossfile is not None:
    num_packets = len(packets)
    # ndmin=1 keeps a single-entry loss trace indexable
    loss = np.loadtxt(args.lossfile, dtype='int16', ndmin=1)
    if len(loss) < num_packets:
        raise ValueError(f'loss trace {args.lossfile} has {len(loss)} entries but {num_packets} packets were encoded')
    fec_out = np.zeros((num_packets * 2, packets[0].shape[-1]), dtype='float32')
    ptr = 0
    count = 2
    for i in range(num_packets):
        if (loss[i] == 0) or (i == num_packets - 1):
            # NOTE(review): a loss burst longer than the redundancy held in
            # packets[i] makes count exceed its length and this assignment
            # fails — confirm whether such traces need handling.
            fec_out[ptr:ptr + count, :] = packets[i][-count:, :]

            ptr += count
            count = 2
        else:
            count += 2

    # Pad to the full dump_data feature width (nb_features columns), the same
    # layout used for the debug output.
    fec_out_full = np.zeros((fec_out.shape[0], nb_features), dtype=np.float32)
    fec_out_full[:, :fec_out.shape[-1]] = fec_out

    fec_out_full.tofile(packet_file[:-4] + '_fec.f32')
193
194
if args.debug_output:
    import itertools

    # Debug dump: rebuild a contiguous feature stream from the packets and
    # write it in the nb_features-wide float32 layout of dump_data.
    batches = [4]
    # presumably the newest and the oldest `batch` subframes of each packet,
    # assuming a packet holds 2 * num_redundancy_frames subframes — verify
    offsets = [0, 2 * args.num_redundancy_frames - 4]

    # sanity checks: for each (batch, offset), concatenate `batch` subframes
    # ending `offset` subframes before the end of every (batch // 2)-th packet
    for batch, offset in itertools.product(batches, offsets):

        # packets are 2-D (subframes, features) arrays: the time length is
        # shape[0]; shape[1] is the feature dimension.
        stop = packets[0].shape[0] - offset
        test_features = np.concatenate([packet[stop - batch: stop, :] for packet in packets[::batch//2]], axis=0)

        test_features_full = np.zeros((test_features.shape[0], nb_features), dtype=np.float32)
        test_features_full[:, :nb_used_features] = test_features[:, :]

        debug_file = packet_file[:-4] + f'_torch_batch{batch}_offset{offset}.f32'
        print(f"writing debug output {debug_file}")
        test_features_full.tofile(debug_file)
213