1""" 2/* Copyright (c) 2023 Amazon 3 Written by Jan Buethe */ 4/* 5 Redistribution and use in source and binary forms, with or without 6 modification, are permitted provided that the following conditions 7 are met: 8 9 - Redistributions of source code must retain the above copyright 10 notice, this list of conditions and the following disclaimer. 11 12 - Redistributions in binary form must reproduce the above copyright 13 notice, this list of conditions and the following disclaimer in the 14 documentation and/or other materials provided with the distribution. 15 16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 20 OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27*/ 28""" 29 30import os 31import multiprocess as multiprocessing 32import random 33import subprocess 34import argparse 35import shutil 36 37import yaml 38 39from utils.files import get_wave_file_list 40from utils.pesq import compute_PESQ 41from utils.pitch import compute_pitch_error 42 43 44parser = argparse.ArgumentParser() 45parser.add_argument('setup', type=str, help='setup yaml specifying end to end processing with model under test') 46parser.add_argument('input_folder', type=str, help='input folder path') 47parser.add_argument('output_folder', type=str, help='output folder path') 48parser.add_argument('--num-testitems', type=int, help="number of testitems to be processed (default 100)", default=100) 49parser.add_argument('--seed', type=int, help='seed for random item selection', default=None) 50parser.add_argument('--fs', type=int, help="sampling rate at which input is presented as wave file (defaults to 16000)", default=16000) 51parser.add_argument('--num-workers', type=int, help="number of subprocesses to be used (default=4)", default=4) 52parser.add_argument('--plc-suffix', type=str, default="_is_lost.txt", help="suffix of plc error pattern file: only relevant if command chain uses PLCFILE (default=_is_lost.txt)") 53parser.add_argument('--metrics', type=str, default='pesq', help='comma separated string of metrics, supported: {{"pesq", "pitch_error", "voicing_error"}}, default="pesq"') 54parser.add_argument('--verbose', action='store_true', help='enables printouts of all commands run in the pipeline') 55 56def check_for_sox_in_path(): 57 r = subprocess.run("sox -h", shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) 58 return r.returncode == 0 59 60 61def run_save_sh(command, verbose=False): 62 63 if verbose: 64 print(f"[run_save_sh] running command {command}...") 65 66 r = subprocess.run(command, shell=True) 67 if r.returncode != 0: 68 raise RuntimeError(f"command '{command}' failed with exit code {r.returncode}") 69 70 71def run_processing_chain(input_path, output_path, model_commands, fs, metrics={'pesq'}, plc_suffix="_is_lost.txt", verbose=False): 72 73 # prepare model input 74 model_input = output_path + ".resamp.wav" 75 
run_save_sh(f"sox {input_path} -r {fs} {model_input}", verbose=verbose) 76 77 plcfile = os.path.splitext(input_path)[0] + plc_suffix 78 if os.path.isfile(plcfile): 79 run_save_sh(f"cp {plcfile} {os.path.dirname(output_path)}") 80 81 # generate model output 82 for command in model_commands: 83 run_save_sh(command.format(INPUT=model_input, OUTPUT=output_path, PLCFILE=plcfile), verbose=verbose) 84 85 scores = dict() 86 cache = dict() 87 for metric in metrics: 88 if metric == 'pesq': 89 # run pesq 90 score = compute_PESQ(input_path, output_path, fs=fs) 91 elif metric == 'pitch_error': 92 if metric in cache: 93 score = cache[metric] 94 else: 95 rval = compute_pitch_error(input_path, output_path, fs=fs) 96 score = rval[metric] 97 cache['voicing_error'] = rval['voicing_error'] 98 elif metric == 'voicing_error': 99 if metric in cache: 100 score = cache[metric] 101 else: 102 rval = compute_pitch_error(input_path, output_path, fs=fs) 103 score = rval[metric] 104 cache['pitch_error'] = rval['pitch_error'] 105 else: 106 ValueError(f'error: unknown metric {metric}') 107 108 scores[metric] = score 109 110 return (output_path, scores) 111 112 113def get_output_path(root_folder, input, output_folder): 114 115 input_relpath = os.path.relpath(input, root_folder) 116 117 os.makedirs(os.path.join(output_folder, 'processing', os.path.dirname(input_relpath)), exist_ok=True) 118 119 output_path = os.path.join(output_folder, 'processing', input_relpath + '.output.wav') 120 121 return output_path 122 123 124def add_audio_table(f, html_folder, results, title, metric): 125 126 item_folder = os.path.join(html_folder, 'items') 127 os.makedirs(item_folder, exist_ok=True) 128 129 # table with results 130 f.write(f""" 131 <div> 132 <h2> {title} </h2> 133 <table> 134 <tr> 135 <th> Rank </th> 136 <th> Name </th> 137 <th> {metric.upper()} </th> 138 <th> Audio (out) </th> 139 <th> Audio (orig) </th> 140 </tr> 141 """) 142 143 for i, r in enumerate(results): 144 item, score = r 145 item_name = os.path.basename(item) 146 new_item_path = os.path.join(item_folder, item_name) 147 shutil.copyfile(item, new_item_path) 148 shutil.copyfile(item + '.resamp.wav', os.path.join(item_folder, item_name + '.orig.wav')) 149 150 f.write(f""" 151 <tr> 152 <td> {i + 1} </td> 153 <td> {item_name.split('.')[0]} </td> 154 <td> {score:.3f} </td> 155 <td> 156 <audio controls> 157 <source src="items/{item_name}"> 158 </audio> 159 </td> 160 <td> 161 <audio controls> 162 <source src="items/{item_name + '.orig.wav'}"> 163 </audio> 164 </td> 165 </tr> 166 """) 167 168 # footer 169 f.write(""" 170 </table> 171 </div> 172 """) 173 174 175def create_html(output_folder, results, title, metric): 176 177 html_folder = output_folder 178 items_folder = os.path.join(html_folder, 'items') 179 os.makedirs(html_folder, exist_ok=True) 180 os.makedirs(items_folder, exist_ok=True) 181 182 with open(os.path.join(html_folder, 'index.html'), 'w') as f: 183 # header and title 184 f.write(f""" 185 <!DOCTYPE html> 186 <html lang="en"> 187 <head> 188 <meta charset="utf-8"> 189 <title>{title}</title> 190 <style> 191 article {{ 192 align-items: flex-start; 193 display: flex; 194 flex-wrap: wrap; 195 gap: 4em; 196 }} 197 html {{ 198 box-sizing: border-box; 199 font-family: "Amazon Ember", "Source Sans", "Verdana", "Calibri", sans-serif; 200 padding: 2em; 201 }} 202 td {{ 203 padding: 3px 7px; 204 text-align: center; 205 }} 206 td:first-child {{ 207 text-align: end; 208 }} 209 th {{ 210 background: #ff9900; 211 color: #000; 212 font-size: 1.2em; 213 padding: 7px 7px; 214 }} 215 
def add_audio_table(f, html_folder, results, title, metric):

    item_folder = os.path.join(html_folder, 'items')
    os.makedirs(item_folder, exist_ok=True)

    # table with results
    f.write(f"""
    <div>
        <h2> {title} </h2>
        <table>
            <tr>
                <th> Rank </th>
                <th> Name </th>
                <th> {metric.upper()} </th>
                <th> Audio (out) </th>
                <th> Audio (orig) </th>
            </tr>
    """)

    for i, r in enumerate(results):
        item, score = r
        item_name = os.path.basename(item)
        new_item_path = os.path.join(item_folder, item_name)
        shutil.copyfile(item, new_item_path)
        shutil.copyfile(item + '.resamp.wav', os.path.join(item_folder, item_name + '.orig.wav'))

        f.write(f"""
            <tr>
                <td> {i + 1} </td>
                <td> {item_name.split('.')[0]} </td>
                <td> {score:.3f} </td>
                <td>
                    <audio controls>
                        <source src="items/{item_name}">
                    </audio>
                </td>
                <td>
                    <audio controls>
                        <source src="items/{item_name + '.orig.wav'}">
                    </audio>
                </td>
            </tr>
        """)

    # footer
    f.write("""
        </table>
    </div>
    """)


def create_html(output_folder, results, title, metric):

    html_folder = output_folder
    items_folder = os.path.join(html_folder, 'items')
    os.makedirs(html_folder, exist_ok=True)
    os.makedirs(items_folder, exist_ok=True)

    with open(os.path.join(html_folder, 'index.html'), 'w') as f:
        # header and title
        f.write(f"""
        <!DOCTYPE html>
        <html lang="en">
        <head>
            <meta charset="utf-8">
            <title>{title}</title>
            <style>
                article {{
                    align-items: flex-start;
                    display: flex;
                    flex-wrap: wrap;
                    gap: 4em;
                }}
                html {{
                    box-sizing: border-box;
                    font-family: "Amazon Ember", "Source Sans", "Verdana", "Calibri", sans-serif;
                    padding: 2em;
                }}
                td {{
                    padding: 3px 7px;
                    text-align: center;
                }}
                td:first-child {{
                    text-align: end;
                }}
                th {{
                    background: #ff9900;
                    color: #000;
                    font-size: 1.2em;
                    padding: 7px 7px;
                }}
            </style>
        </head>
        <body>
            <h1>{title}</h1>
            <article>
        """)

        # results are sorted in ascending order of sign-adjusted score, so the
        # best items sit at the end of the list

        # top 20
        add_audio_table(f, html_folder, results[:-21:-1], "Top 20", metric)

        # 20 around median
        N = len(results) // 2
        add_audio_table(f, html_folder, results[N + 10 : N - 10 : -1], "Median 20", metric)

        # flop 20
        add_audio_table(f, html_folder, results[:20], "Flop 20", metric)

        # footer
        f.write("""
            </article>
        </body>
        </html>
        """)


# sign of the sorting key: +1 for higher-is-better metrics, -1 for
# lower-is-better metrics
metric_sorting_signs = {
    'pesq'          : 1,
    'pitch_error'   : -1,
    'voicing_error' : -1
}


def is_valid_result(data, metrics):
    if not isinstance(data, dict):
        return False

    for metric in metrics:
        if metric not in data:
            return False

    return True


def evaluate_results(output_folder, results, metric):

    results = sorted(results, key=lambda x: metric_sorting_signs[metric] * x[1])
    with open(os.path.join(output_folder, f'scores_{metric}.txt'), 'w') as f:
        for result in results:
            f.write(f"{os.path.relpath(result[0], output_folder)} {result[1]}\n")

    # some statistics (slices may hold fewer than 20 items on small runs)
    mean = sum(r[1] for r in results) / len(results)
    top = results[-20:]
    bottom = results[:20]
    top_mean = sum(r[1] for r in top) / len(top)
    bottom_mean = sum(r[1] for r in bottom) / len(bottom)

    with open(os.path.join(output_folder, f'stats_{metric}.txt'), 'w') as f:
        f.write(f"mean score: {mean}\n")
        f.write(f"bottom mean score: {bottom_mean}\n")
        f.write(f"top mean score: {top_mean}\n")

    print(f"\nmean score: {mean}")
    print(f"bottom mean score: {bottom_mean}")
    print(f"top mean score: {top_mean}\n")

    # create output html
    create_html(os.path.join(output_folder, 'html', metric), results, setup['test'], metric)
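# Example invocation (folder and file names are hypothetical, and the script
# name assumes this file is saved as run_test.py):
#
#     python run_test.py setup.yml /data/clean_speech /tmp/eval_out \
#         --metrics pesq,pitch_error --num-testitems 50 --seed 1234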
if __name__ == "__main__":
    args = parser.parse_args()

    # check for sox
    if not check_for_sox_in_path():
        raise RuntimeError("script requires sox")

    # prepare output folder
    if os.path.exists(args.output_folder):
        print("warning: output folder exists")

        reply = input('continue? (y/n): ')
        while reply not in {'y', 'n'}:
            reply = input('continue? (y/n): ')

        if reply == 'n':
            sys.exit(0)
        else:
            # start with a clean slate
            shutil.rmtree(args.output_folder)

    os.makedirs(args.output_folder, exist_ok=True)

    # extract metrics
    metrics = args.metrics.split(",")
    for metric in metrics:
        if metric not in metric_sorting_signs:
            parser.error(f"unknown metric {metric}")

    # read setup
    print(f"loading {args.setup}...")
    with open(args.setup, "r") as f:
        setup = yaml.load(f.read(), yaml.FullLoader)

    model_commands = setup['processing']

    print("\nfound the following model commands:")
    for command in model_commands:
        print(command.format(INPUT='input.wav', OUTPUT='output.wav', PLCFILE='input_is_lost.txt'))

    # store setup to output folder
    setup['input'] = os.path.abspath(args.input_folder)
    setup['output'] = os.path.abspath(args.output_folder)
    setup['seed'] = args.seed
    with open(os.path.join(args.output_folder, 'setup.yml'), 'w') as f:
        yaml.dump(setup, f)

    # get input
    print(f"\nCollecting audio files from {args.input_folder}...")
    file_list = get_wave_file_list(args.input_folder, check_for_features=False)
    print(f"...{len(file_list)} files found\n")

    # sample from file list: sort first so the selection is reproducible for a
    # fixed seed regardless of filesystem enumeration order
    file_list = sorted(file_list)
    random.seed(args.seed)
    random.shuffle(file_list)
    num_testitems = min(args.num_testitems, len(file_list))
    file_list = file_list[:num_testitems]

    print(f"\nlaunching test on {num_testitems} items...")

    # helper function for parallel processing; failures are marked with a
    # non-dict score value and filtered out below
    def func(input_path):
        output_path = get_output_path(args.input_folder, input_path, args.output_folder)

        try:
            rval = run_processing_chain(input_path, output_path, model_commands, args.fs, metrics=metrics, plc_suffix=args.plc_suffix, verbose=args.verbose)
        except Exception:
            rval = (input_path, -1)

        return rval

    with multiprocessing.Pool(args.num_workers) as p:
        results = p.map(func, file_list)

    # keep only items for which all requested metrics were computed
    results_dict = dict()
    for name, values in results:
        if is_valid_result(values, metrics):
            results_dict[name] = values

    print(results_dict)

    # evaluating results
    num_failures = num_testitems - len(results_dict)
    print(f"\nprocessing of {num_failures} items failed\n")

    for metric in metrics:
        print(metric)
        evaluate_results(
            args.output_folder,
            [(name, value[metric]) for name, value in results_dict.items()],
            metric
        )
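# On completion the output folder contains: setup.yml (the setup with input,
# output and seed recorded), processing/ (a mirror of the input tree holding
# the .output.wav and .resamp.wav files per item), scores_<metric>.txt and
# stats_<metric>.txt per metric, and html/<metric>/index.html with playable
# Top/Median/Flop 20 tables.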