1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
30import os
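# multiprocess is a dill-based fork of the standard multiprocessing module;
# unlike the standard module it can pickle the locally defined worker
# function that is passed to Pool.map in the main block below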
import multiprocess as multiprocessing
import random
import subprocess
import argparse
import shutil

import yaml

from utils.files import get_wave_file_list
from utils.pesq import compute_PESQ
from utils.pitch import compute_pitch_error


parser = argparse.ArgumentParser()
parser.add_argument('setup', type=str, help='setup yaml specifying end-to-end processing with model under test')
parser.add_argument('input_folder', type=str, help='input folder path')
parser.add_argument('output_folder', type=str, help='output folder path')
parser.add_argument('--num-testitems', type=int, help="number of test items to be processed (default 100)", default=100)
parser.add_argument('--seed', type=int, help='seed for random item selection', default=None)
parser.add_argument('--fs', type=int, help="sampling rate at which input is presented as wave file (defaults to 16000)", default=16000)
parser.add_argument('--num-workers', type=int, help="number of subprocesses to be used (default 4)", default=4)
parser.add_argument('--plc-suffix', type=str, default="_is_lost.txt", help="suffix of PLC error pattern file; only relevant if command chain uses PLCFILE (default _is_lost.txt)")
parser.add_argument('--metrics', type=str, default='pesq', help='comma-separated list of metrics, supported: {"pesq", "pitch_error", "voicing_error"}, default "pesq"')
parser.add_argument('--verbose', action='store_true', help='enables printouts of all commands run in the pipeline')

def check_for_sox_in_path():
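    """Return True if sox can be invoked from the shell."""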
    r = subprocess.run("sox -h", shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    return r.returncode == 0


def run_save_sh(command, verbose=False):
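    """Run a shell command and raise a RuntimeError on a non-zero exit code."""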
    if verbose:
        print(f"[run_save_sh] running command {command}...")

    r = subprocess.run(command, shell=True)
    if r.returncode != 0:
        raise RuntimeError(f"command '{command}' failed with exit code {r.returncode}")


def run_processing_chain(input_path, output_path, model_commands, fs, metrics={'pesq'}, plc_suffix="_is_lost.txt", verbose=False):
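    """Resample the input item, run the model command chain on it, and
    compute the requested metrics on the model output."""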
    # prepare model input
    model_input = output_path + ".resamp.wav"
    run_save_sh(f"sox {input_path} -r {fs} {model_input}", verbose=verbose)

    plcfile = os.path.splitext(input_path)[0] + plc_suffix
    if os.path.isfile(plcfile):
        run_save_sh(f"cp {plcfile} {os.path.dirname(output_path)}")

    # generate model output
    for command in model_commands:
        run_save_sh(command.format(INPUT=model_input, OUTPUT=output_path, PLCFILE=plcfile), verbose=verbose)

    scores = dict()
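    # compute_pitch_error returns both pitch_error and voicing_error, so
    # whichever of the two metrics is requested first caches the other's score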
    cache = dict()
    for metric in metrics:
        if metric == 'pesq':
            # run pesq
            score = compute_PESQ(input_path, output_path, fs=fs)
        elif metric == 'pitch_error':
            if metric in cache:
                score = cache[metric]
            else:
                rval = compute_pitch_error(input_path, output_path, fs=fs)
                score = rval[metric]
                cache['voicing_error'] = rval['voicing_error']
        elif metric == 'voicing_error':
            if metric in cache:
                score = cache[metric]
            else:
                rval = compute_pitch_error(input_path, output_path, fs=fs)
                score = rval[metric]
                cache['pitch_error'] = rval['pitch_error']
        else:
            raise ValueError(f'error: unknown metric {metric}')

        scores[metric] = score

    return (output_path, scores)


def get_output_path(root_folder, input, output_folder):
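    """Mirror the input folder structure under <output_folder>/processing and return the output wave path."""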
    input_relpath = os.path.relpath(input, root_folder)

    os.makedirs(os.path.join(output_folder, 'processing', os.path.dirname(input_relpath)), exist_ok=True)

    output_path = os.path.join(output_folder, 'processing', input_relpath + '.output.wav')

    return output_path


def add_audio_table(f, html_folder, results, title, metric):
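    """Write an HTML table of ranked items with embedded audio players and
    copy the referenced audio files into <html_folder>/items."""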
    item_folder = os.path.join(html_folder, 'items')
    os.makedirs(item_folder, exist_ok=True)

    # table with results
    f.write(f"""
            <div>
            <h2> {title} </h2>
            <table>
            <tr>
                <th> Rank   </th>
                <th> Name   </th>
                <th> {metric.upper()} </th>
                <th> Audio (out)  </th>
                <th> Audio (orig)  </th>
            </tr>
            """)

    for i, r in enumerate(results):
        item, score = r
        item_name = os.path.basename(item)
        new_item_path = os.path.join(item_folder, item_name)
        shutil.copyfile(item, new_item_path)
        shutil.copyfile(item + '.resamp.wav', os.path.join(item_folder, item_name + '.orig.wav'))

        f.write(f"""
                <tr>
                    <td> {i + 1} </td>
                    <td> {item_name.split('.')[0]} </td>
                    <td> {score:.3f} </td>
                    <td>
                        <audio controls>
                            <source src="items/{item_name}">
                        </audio>
                    </td>
                    <td>
                        <audio controls>
                            <source src="items/{item_name + '.orig.wav'}">
                        </audio>
                    </td>
                </tr>
                """)

    # footer
    f.write("""
            </table>
            </div>
            """)


def create_html(output_folder, results, title, metric):
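    """Create an index.html with Top 20, Median 20 and Flop 20 tables for a single metric."""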
    html_folder = output_folder
    items_folder = os.path.join(html_folder, 'items')
    os.makedirs(html_folder, exist_ok=True)
    os.makedirs(items_folder, exist_ok=True)

    with open(os.path.join(html_folder, 'index.html'), 'w') as f:
        # header and title
        f.write(f"""
                <!DOCTYPE html>
                <html lang="en">
                <head>
                    <meta charset="utf-8">
                    <title>{title}</title>
                    <style>
                        article {{
                            align-items: flex-start;
                            display: flex;
                            flex-wrap: wrap;
                            gap: 4em;
                        }}
                        html {{
                            box-sizing: border-box;
                            font-family: "Amazon Ember", "Source Sans", "Verdana", "Calibri", sans-serif;
                            padding: 2em;
                        }}
                        td {{
                            padding: 3px 7px;
                            text-align: center;
                        }}
                        td:first-child {{
                            text-align: end;
                        }}
                        th {{
                            background: #ff9900;
                            color: #000;
                            font-size: 1.2em;
                            padding: 7px 7px;
                        }}
                    </style>
                </head>
                <body>
                <h1>{title}</h1>
                <article>
                """)

        # top 20 (results are sorted from worst to best, so the last
        # 20 entries, reversed, give the best items first)
        add_audio_table(f, html_folder, results[:-21: -1], "Top 20", metric)

        # 20 around median, best first
        N = len(results) // 2
        add_audio_table(f, html_folder, results[N + 10 : N - 10: -1], "Median 20", metric)

        # flop 20 (the sorted list starts with the worst items)
        add_audio_table(f, html_folder, results[:20], "Flop 20", metric)

        # footer
        f.write("""
                </article>
                </body>
                </html>
                """)

# sign convention: +1 for metrics where higher is better, -1 where lower is
# better; sorting by sign * score then always puts the worst items first
metric_sorting_signs = {
    'pesq'          : 1,
    'pitch_error'   : -1,
    'voicing_error' : -1
}

def is_valid_result(data, metrics):
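    """Return True if data is a dict containing a score for every requested metric."""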
    if not isinstance(data, dict):
        return False

    for metric in metrics:
        if metric not in data:
            return False

    return True


def evaluate_results(output_folder, results, metric):
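    """Write per-item scores and summary statistics for one metric and
    create the corresponding HTML report."""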
    results = sorted(results, key=lambda x : metric_sorting_signs[metric] * x[1])
    with open(os.path.join(output_folder, f'scores_{metric}.txt'), 'w') as f:
        for result in results:
            f.write(f"{os.path.relpath(result[0], output_folder)} {result[1]}\n")

    # some statistics
    mean = sum([r[1] for r in results]) / len(results)
    top_mean = sum([r[1] for r in results[-20:]]) / 20
    bottom_mean = sum([r[1] for r in results[:20]]) / 20

    with open(os.path.join(output_folder, f'stats_{metric}.txt'), 'w') as f:
        f.write(f"mean score: {mean}\n")
        f.write(f"bottom mean score: {bottom_mean}\n")
        f.write(f"top mean score: {top_mean}\n")

    print(f"\nmean score: {mean}")
    print(f"bottom mean score: {bottom_mean}")
    print(f"top mean score: {top_mean}\n")

    # create output html (setup is read from the global scope of the main block)
    create_html(os.path.join(output_folder, 'html', metric), results, setup['test'], metric)

if __name__ == "__main__":
    args = parser.parse_args()

    # check for sox
    if not check_for_sox_in_path():
        raise RuntimeError("script requires sox")


    # prepare output folder
    if os.path.exists(args.output_folder):
        print("warning: output folder exists")

        reply = input('continue? (y/n): ')
        while reply not in {'y', 'n'}:
            reply = input('continue? (y/n): ')

        if reply == 'n':
            os._exit(0)
        else:
            # start with a clean slate
            shutil.rmtree(args.output_folder)

    os.makedirs(args.output_folder, exist_ok=True)

    # extract metrics
    metrics = args.metrics.split(",")
    for metric in metrics:
        if metric not in metric_sorting_signs:
            # parser.error prints the usage string and exits
            parser.error(f"unknown metric {metric}")

    # read setup
    print(f"loading {args.setup}...")
    with open(args.setup, "r") as f:
        setup = yaml.load(f.read(), yaml.FullLoader)

    model_commands = setup['processing']

    print("\nfound the following model commands:")
    for command in model_commands:
        print(command.format(INPUT='input.wav', OUTPUT='output.wav', PLCFILE='input_is_lost.txt'))

    # store setup to output folder
    setup['input']  = os.path.abspath(args.input_folder)
    setup['output'] = os.path.abspath(args.output_folder)
    setup['seed']   = args.seed
    with open(os.path.join(args.output_folder, 'setup.yml'), 'w') as f:
        yaml.dump(setup, f)

    # get input
    print(f"\nCollecting audio files from {args.input_folder}...")
    file_list = get_wave_file_list(args.input_folder, check_for_features=False)
    print(f"...{len(file_list)} files found\n")

    # sample from file list
    file_list = sorted(file_list)
    random.seed(args.seed)
    random.shuffle(file_list)
    num_testitems = min(args.num_testitems, len(file_list))
    file_list = file_list[:num_testitems]


    print(f"\nlaunching test on {num_testitems} items...")
    # helper function for parallel processing
    def func(input_path):
        output_path = get_output_path(args.input_folder, input_path, args.output_folder)

        try:
            rval = run_processing_chain(input_path, output_path, model_commands, args.fs, metrics=metrics, plc_suffix=args.plc_suffix, verbose=args.verbose)
        except Exception:
            # mark failed items; they are filtered out by is_valid_result below
            rval = (input_path, -1)

        return rval

    with multiprocessing.Pool(args.num_workers) as p:
        results = p.map(func, file_list)

    results_dict = dict()
    for name, values in results:
        if is_valid_result(values, metrics):
            results_dict[name] = values

    print(results_dict)

    # evaluate results
    num_failures = num_testitems - len(results_dict)
    print(f"\nprocessing of {num_failures} items failed\n")

    for metric in metrics:
        print(metric)
        evaluate_results(
            args.output_folder,
            [(name, value[metric]) for name, value in results_dict.items()],
            metric
        )