"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import os
import argparse
import tempfile
import shutil

import pandas as pd
from scipy.spatial.distance import cdist
from scipy.io import wavfile
import numpy as np

from nomad_audio.nomad import Nomad

# Scores the 'opus', 'lace' and 'nolace' conditions of every processed item
# with NOMAD, using the 'clean' condition as reference, and saves the scores
# to <folder>/results_nomad.npy.

# Condition sub-folders expected below the input folder; 'clean' is the reference.
DEGRADED_CONDITIONS = ('opus', 'lace', 'nolace')
CONDITIONS = ('clean',) + DEGRADED_CONDITIONS


parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='folder with processed items')
parser.add_argument('--full-reference', action='store_true', help='use NOMAD as full-reference metric')
parser.add_argument('--device', type=str, default=None, help='device for Nomad')


def get_bitrates(folder):
    """Return the bitrates listed in ``<folder>/bitrates.txt``.

    The file is read as whitespace-separated integers.
    """
    with open(os.path.join(folder, 'bitrates.txt')) as f:
        return [int(token) for token in f.read().split()]


def get_itemlist(folder):
    """Return the item names listed in ``<folder>/items.txt``.

    Only the first whitespace-separated token of each line is used; blank
    lines are skipped.
    """
    with open(os.path.join(folder, 'items.txt')) as f:
        return [line.split()[0] for line in f if line.strip()]


def _stem(path):
    """Return the file name of ``path`` without its directory and final extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _folder_embeddings(model, folder):
    """Compute embeddings for every file in ``folder`` (sorted by name), indexed by full path."""
    data = pd.DataFrame({'filename': [os.path.join(folder, x) for x in sorted(os.listdir(folder))]})
    return model.get_embeddings_csv(model.model, data).set_index('filename')


def nomad_wrapper(ref_folder, deg_folder, full_reference=False, ref_embeddings=None, device=None):
    """Score every file in ``deg_folder`` with NOMAD.

    Args:
        ref_folder: folder with the reference files.
        deg_folder: folder with the files to score.
        full_reference: if False, call ``Nomad.predict`` with ``ref_folder``
            passed as ``nmr``. If True, score each degraded file by the
            Euclidean distance between its embedding and the embedding of the
            reference file with the same file name.
        ref_embeddings: reference embeddings returned by an earlier call
            (full-reference mode only). Computed from ``ref_folder`` if None.
        device: device passed to ``Nomad``.

    Returns:
        ``(results, ref_embeddings)``. ``results`` maps a file key to its
        score. In full-reference mode the key is the file name without
        extension. Otherwise the keys are whatever ``Nomad.predict`` reports
        (presumably the same stems -- TODO confirm). ``ref_embeddings`` is
        None in non-full-reference mode.

    Raises:
        KeyError: in full-reference mode, if a degraded file has no
            reference file of the same name.
    """
    model = Nomad(device=device)
    if not full_reference:
        results = model.predict(nmr=ref_folder, deg=deg_folder)[0].to_dict()['NOMAD']
        return results, None

    if ref_embeddings is None:
        print(f"Computing reference embeddings from {ref_folder}")
        ref_embeddings = _folder_embeddings(model, ref_folder)

    print(f"Computing degraded embeddings from {deg_folder}")
    deg_embeddings = _folder_embeddings(model, deg_folder)

    # Pair rows explicitly by file name rather than relying on both folder
    # listings having the same order and length.
    ref_by_name = ref_embeddings.rename(index=os.path.basename)
    deg_names = [os.path.basename(p) for p in deg_embeddings.index]
    ref_aligned = ref_by_name.loc[deg_names].to_numpy()

    # Row-wise Euclidean distance: the diagonal of cdist(ref, deg) without
    # building the full N x N distance matrix.
    dist = np.linalg.norm(ref_aligned - deg_embeddings.to_numpy(), axis=1)

    results = {_stem(p): d for p, d in zip(deg_embeddings.index, dist)}
    return results, ref_embeddings


def nomad_process_all(folder, full_reference=False, device=None):
    """Compute NOMAD scores for the opus, lace and nolace conditions of all items.

    Expects ``<folder>/<cond>/<item>_<br>_<cond>.wav`` for every item in
    ``items.txt``, every bitrate in ``bitrates.txt`` and every condition in
    clean/opus/lace/nolace. The files are copied into a temporary directory
    as ``<cond>/<item>_<br>.wav``, so reference and degraded files share
    file names.

    Returns:
        dict mapping each bitrate to a ``(len(items), 3)`` array. Columns 0, 1
        and 2 hold the opus, lace and nolace scores respectively.
    """
    bitrates = get_bitrates(folder)
    items = get_itemlist(folder)

    with tempfile.TemporaryDirectory() as tmpdir:
        cond_dirs = {cond: os.path.join(tmpdir, cond) for cond in CONDITIONS}
        for cond_dir in cond_dirs.values():
            os.makedirs(cond_dir)

        for br in bitrates:
            for item in items:
                for cond in CONDITIONS:
                    shutil.copyfile(
                        os.path.join(folder, cond, f"{item}_{br}_{cond}.wav"),
                        os.path.join(cond_dirs[cond], f"{item}_{br}.wav"),
                    )

        # The reference embeddings returned by the first call are reused by
        # the later ones. The device is forwarded so that --device takes effect.
        scores = {}
        ref_embeddings = None
        for cond in DEGRADED_CONDITIONS:
            scores[cond], ref_embeddings = nomad_wrapper(
                cond_dirs['clean'],
                cond_dirs[cond],
                full_reference=full_reference,
                ref_embeddings=ref_embeddings,
                device=device,
            )

    results = dict()
    for br in bitrates:
        results[br] = np.zeros((len(items), len(DEGRADED_CONDITIONS)))
        for i, item in enumerate(items):
            key = f"{item}_{br}"
            for j, cond in enumerate(DEGRADED_CONDITIONS):
                results[br][i, j] = scores[cond][key]

    return results


if __name__ == "__main__":
    args = parser.parse_args()

    results = nomad_process_all(args.folder, full_reference=args.full_reference, device=args.device)

    # ``results`` is a dict, so numpy stores it as a pickled 0-d object array;
    # read it back with np.load(path, allow_pickle=True).item().
    np.save(os.path.join(args.folder, 'results_nomad.npy'), results)

    print("Done.")