# xref: /aosp_15_r20/external/libopus/dnn/torch/neural-pitch/experiments.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
"""
Run the experiments:
    1. RCA vs. SNR for our models, CREPE, and LPCNet.
"""

import argparse
import os

# Command-line interface. `description=__doc__` reuses the module docstring
# (if any) as the --help summary; the raw formatter keeps its indentation.
parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)

parser.add_argument('ptdb_root', type=str,
                    help='Root directory of the PTDB data generated by running ptdb_process.sh')
parser.add_argument('output', type=str,
                    help='Path of the JSON file the results are written to')
parser.add_argument('method', type=str, choices=['model', 'lpcnet', 'crepe'],
                    help="Method to evaluate; 'model' evaluates the checkpoint given by --pth_file")
parser.add_argument('--noise_dataset', type=str, default='./',
                    help='Location of the Demand dataset')
parser.add_argument('--noise_type', type=str, default='synthetic', choices=['synthetic', 'demand'],
                    help='Type of additive noise')
parser.add_argument('--pth_file', type=str, default='./',
                    help="Model checkpoint (.pth) to analyze; required when method is 'model'")
# Float defaults are written as floats: argparse does not run non-string
# defaults through `type`, so `1`/`50` would otherwise stay ints.
parser.add_argument('--fraction_files_analyze', type=float, default=1.0,
                    help='Fraction of the PTDB dataset to test on, in (0, 1]')
parser.add_argument('--threshold_rca', type=float, default=50.0,
                    help='Cent threshold when computing RCA')
parser.add_argument('--gpu_index', type=int, default=0,
                    help='GPU index to use if multiple GPUs')

args = parser.parse_args()

# Validate arguments up front so misconfiguration is reported (with usage,
# exit status 2) before any evaluation work starts.
if not 0.0 < args.fraction_files_analyze <= 1.0:
    parser.error('--fraction_files_analyze must be in (0, 1]')
if args.method == 'model' and not os.path.isfile(args.pth_file):
    parser.error("method 'model' requires --pth_file to point to an existing .pth file")
# Restrict CUDA to the requested GPU. This has to happen before `evaluation`
# is imported below: that module presumably initializes CUDA (via torch), and
# CUDA only reads these variables at initialization time.
import json
import os

os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",  # enumerate GPUs in PCI bus order for --gpu_index
    "CUDA_VISIBLE_DEVICES": f"{args.gpu_index}",
})

from evaluation import cycle_eval  # noqa: E402  (must follow the environment setup above)
# SNR grid shared by every method (values look like dB -- confirm against
# evaluation.cycle_eval).
SNR_LIST = [-20, -15, -10, -5, 0, 5, 10, 15, 20]

# cycle_eval gets the checkpoint path for our own model, otherwise the
# baseline's name ('lpcnet' / 'crepe'). Everything else is identical, so the
# call is written once to keep both cases in sync.
method_spec = args.pth_file if args.method == 'model' else args.method
dict_store = cycle_eval(
    [method_spec],
    noise_type=args.noise_type,
    noise_dataset=args.noise_dataset,
    list_snr=SNR_LIST,
    ptdb_dataset_path=args.ptdb_root,
    fraction=args.fraction_files_analyze,
    thresh=args.threshold_rca,
)

# Record what was evaluated alongside the results.
dict_store["method"] = args.method
if args.method == 'model':
    dict_store['pth'] = args.pth_file

with open(args.output, 'w', encoding='utf-8') as fp:
    json.dump(dict_store, fp)