xref: /aosp_15_r20/external/armnn/python/pyarmnn/examples/tests/test_mfcc.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
2# SPDX-License-Identifier: MIT
3
4import os
5import numpy as np
6import pytest
7import collections
8
9from context import mfcc
10from context import wav2letter_mfcc
11from context import audio_capture
12
13# Elements relevant to MFCC filter bank & feature extraction
14MFCC_TEST_PARAMS = collections.namedtuple('mfcc_test_params',
15                                          ['algo_params', 'mfcc_constructor', 'audio_proc_constructor'])
16
17
18def kws_test_params():
19    kws_algo_params = mfcc.MFCCParams(sampling_freq=16000, num_fbank_bins=40, mel_lo_freq=20, mel_hi_freq=4000,
20                                      num_mfcc_feats=10, frame_len=640, use_htk_method=True, n_fft=1024)
21    return MFCC_TEST_PARAMS(kws_algo_params, mfcc.MFCC, mfcc.AudioPreprocessor)
22
23
24def asr_test_params():
25    asr_algo_params = mfcc.MFCCParams(sampling_freq=16000, num_fbank_bins=128, mel_lo_freq=0, mel_hi_freq=8000,
26                                      num_mfcc_feats=13, frame_len=512, use_htk_method=False, n_fft=512)
27    return MFCC_TEST_PARAMS(asr_algo_params, wav2letter_mfcc.Wav2LetterMFCC, wav2letter_mfcc.W2LAudioPreprocessor)
28
29
30def kws_cap_params():
31    return audio_capture.AudioCaptureParams(dtype=np.float32, overlap=0, min_samples=16000, sampling_freq=16000,
32                                            mono=True)
33
34
35def asr_cap_params():
36    return audio_capture.AudioCaptureParams(dtype=np.float32, overlap=31712, min_samples=47712,
37                                            sampling_freq=16000, mono=True)
38
39
40@pytest.fixture()
41def audio_data(test_data_folder, file, audio_cap_params):
42    audio_file = os.path.join(test_data_folder, file)
43    capture = audio_capture.capture_audio(audio_file, audio_cap_params)
44    yield next(capture)
45
46
47@pytest.mark.parametrize("file", ["yes.wav", "myVoiceIsMyPassportVerifyMe04.wav"])
48@pytest.mark.parametrize("audio_cap_params", [kws_cap_params(), asr_cap_params()])
49def test_audio_file(audio_data, test_data_folder, file, audio_cap_params):
50    assert audio_data.shape == (audio_cap_params.min_samples,)
51    assert audio_data.dtype == audio_cap_params.dtype
52
53
54@pytest.mark.parametrize("mfcc_test_params, test_out", [(kws_test_params(), 25.470010570730597),
55                                                        (asr_test_params(), 0.24)])
56def test_mel_scale_function(mfcc_test_params, test_out):
57    mfcc_inst = mfcc_test_params.mfcc_constructor(mfcc_test_params.algo_params)
58    mel = mfcc_inst.mel_scale(16, mfcc_test_params.algo_params.use_htk_method)
59    assert np.isclose(mel, test_out)
60
61
62@pytest.mark.parametrize("mfcc_test_params, test_out", [(kws_test_params(), 10.008767240008943),
63                                                        (asr_test_params(), 1071.170287494467)])
64def test_inverse_mel_scale_function(mfcc_test_params, test_out):
65    mfcc_inst = mfcc_test_params.mfcc_constructor(mfcc_test_params.algo_params)
66    mel = mfcc_inst.inv_mel_scale(16, mfcc_test_params.algo_params.use_htk_method)
67    assert np.isclose(mel, test_out)
68
69
70mel_filter_test_data_kws = {0: [0.33883214, 0.80088392, 0.74663128, 0.30332531],
71                            1: [0.25336872, 0.69667469, 0.86883317, 0.44281119, 0.02493546],
72                            2: [0.13116683, 0.55718881, 0.97506454, 0.61490026, 0.21241678],
73                            5: [0.32725038, 0.69579596, 0.9417706, 0.58524989, 0.23445207],
74                            -1: [0.02433275, 0.10371618, 0.1828123, 0.26162319, 0.34015089, 0.41839743,
75                                 0.49636481, 0.57405503, 0.65147004, 0.72861179, 0.8054822, 0.88208318,
76                                 0.95841659, 0.96551568, 0.88971181, 0.81416996, 0.73888833, 0.66386514,
77                                 0.58909861, 0.514587, 0.44032856, 0.3663216, 0.29256441, 0.21905531,
78                                 0.14579264, 0.07277474]}
79
80mel_filter_test_data_asr = {0: [0.02837754],
81                            1: [0.01438901, 0.01398853],
82                            2: [0.02877802],
83                            5: [0.01478948, 0.01358806],
84                            -1: [4.82151203e-05, 9.48791110e-04, 1.84569875e-03, 2.73896782e-03,
85                                 3.62862771e-03, 4.51470746e-03, 5.22215439e-03, 4.34314914e-03,
86                                 3.46763895e-03, 2.59559614e-03, 1.72699334e-03, 8.61803536e-04]}
87
88
89@pytest.mark.parametrize("mfcc_test_params, test_out",
90                         [(kws_test_params(), mel_filter_test_data_kws),
91                          (asr_test_params(), mel_filter_test_data_asr)])
92def test_create_mel_filter_bank(mfcc_test_params, test_out):
93    mfcc_inst = mfcc_test_params.mfcc_constructor(mfcc_test_params.algo_params)
94    mel_filter_bank = mfcc_inst.create_mel_filter_bank()
95    assert len(mel_filter_bank) == mfcc_test_params.algo_params.num_fbank_bins
96    for indx, data in test_out.items():
97        assert np.allclose(mel_filter_bank[indx], data)
98
99
100mfcc_test_data_kws = (-22.671347398982626, -0.6161543999707211, 2.072326974167832,
101                      0.5813741475362223, 1.0165529747334272, 0.8581560719988703,
102                      0.4603911069624896, 0.03392820944377398, 1.1651093266902361,
103                      0.007200025869960908)
104
105mfcc_test_data_asr = (-735.46345398, 69.50331943, 16.39159347, 22.74874819, 24.84782893,
106                      10.67559303, 12.82828618, -3.51084271, 4.66633677, 10.20079095, 11.34782948, 3.90499354,
107                      9.32322384)
108
109
110@pytest.mark.parametrize("mfcc_test_params, test_out, file, audio_cap_params",
111                         [(kws_test_params(), mfcc_test_data_kws, "yes.wav", kws_cap_params()),
112                          (asr_test_params(), mfcc_test_data_asr, "myVoiceIsMyPassportVerifyMe04.wav",
113                           asr_cap_params())])
114def test_mfcc_compute_first_frame(audio_data, mfcc_test_params, test_out, file, audio_cap_params):
115    audio_data = np.array(audio_data)[0:mfcc_test_params.algo_params.frame_len]
116    mfcc_inst = mfcc_test_params.mfcc_constructor(mfcc_test_params.algo_params)
117    mfcc_feats = mfcc_inst.mfcc_compute(audio_data)
118    assert np.allclose((mfcc_feats[0:mfcc_test_params.algo_params.num_mfcc_feats]), test_out)
119
120
121extract_test_data_kws = {0: [-2.2671347e+01, -6.1615437e-01, 2.0723269e+00, 5.8137417e-01,
122                             1.0165529e+00, 8.5815609e-01, 4.6039110e-01, 3.3928208e-02,
123                             1.1651093e+00, 7.2000260e-03],
124                         1: [-23.488806, -1.1687667, 3.0548365, 1.5129884, 1.4142203,
125                             0.6869772, 1.1875846, 0.5743369, 1.202258, -0.12133602],
126                         2: [-23.909292, -1.5186096, 1.8721082, 0.7378916, 0.44974303,
127                             0.17609395, 0.5183161, 0.37109664, 0.14186797, 0.58400506],
128                         -1: [-23.752186, -0.1796912, 1.9514247, 0.32554424, 1.8425112,
129                              0.8763608, 0.78326845, 0.27808753, 0.73788685, 0.30338883]}
130
131extract_test_data_asr = {0: [-4.98830318e+00, 6.86444461e-01, 3.12024504e-01, 3.56840312e-01,
132                             3.71638149e-01, 2.71728605e-01, 2.86904365e-01, 1.71718955e-01,
133                             2.29365349e-01, 2.68381387e-01, 2.76467651e-01, 2.23998129e-01,
134                             2.62194842e-01, -1.48247385e+01, 1.21875501e+00, 4.20235842e-01,
135                             5.39400637e-01, 6.09882712e-01, 1.68513224e-01, 3.75330061e-01,
136                             8.57576132e-02, 1.92831963e-01, 1.41814977e-01, 1.57615796e-01,
137                             7.19076321e-02, 1.98729336e-01, 3.92199278e+00, -5.76856315e-01,
138                             1.17938723e-02, -9.25096497e-02, -3.59488949e-02, 1.13284402e-03,
139                             1.51282102e-01, 1.13404110e-01, -8.69824737e-02, -1.48449212e-01,
140                             -1.24230251e-01, -1.90728232e-01, -5.37525006e-02],
141                         1: [-4.96694946e+00, 6.69411421e-01, 2.86189795e-01, 3.65071595e-01,
142                             3.92671198e-01, 2.44258150e-01, 2.52177566e-01, 2.16024980e-01,
143                             2.79812217e-01, 2.79687315e-01, 2.95228422e-01, 2.83991724e-01,
144                             2.46358261e-01, -1.33618221e+01, 1.08920455e+00, 3.88707787e-01,
145                             5.05674303e-01, 6.08285785e-01, 1.68113053e-01, 3.54529470e-01,
146                             6.68609440e-02, 1.52882755e-01, 6.89579248e-02, 1.18375972e-01,
147                             5.86742274e-02, 1.15678251e-01, 1.07892036e+01, -1.07193100e+00,
148                             -2.18140319e-01, -3.35950345e-01, -2.57241666e-01, -5.54431602e-02,
149                             -8.38544443e-02, -5.79114584e-03, -2.23973781e-01, -2.91451365e-01,
150                             -2.11069033e-01, -1.90297231e-01, -2.76504964e-01],
151                         2: [-4.98664522e+00, 6.54802263e-01, 3.70355755e-01, 4.06837821e-01,
152                             4.05175537e-01, 2.29149669e-01, 2.83312678e-01, 2.17573136e-01,
153                             3.07824671e-01, 2.48388007e-01, 2.25399241e-01, 2.52003014e-01,
154                             2.83968121e-01, -1.05043650e+01, 7.91533887e-01, 3.11546475e-01,
155                             4.36079264e-01, 5.93271911e-01, 2.02480286e-01, 3.24254721e-01,
156                             6.29674867e-02, 9.67641100e-02, -1.62826646e-02, 5.47595806e-02,
157                             2.90475693e-02, 2.62522381e-02, 1.38787737e+01, -1.32597208e+00,
158                             -3.73900205e-01, -4.38065380e-01, -3.05983245e-01, 1.14390980e-02,
159                             -2.10821658e-01, -6.22789040e-02, -2.88273603e-01, -3.29794526e-01,
160                             -2.43764088e-01, -1.70954674e-01, -3.65193188e-01],
161                         -1: [-2.1894817, 1.583355, -0.45024827, 0.11657667, 0.08940444, 0.09041209,
162                              0.2003613, 0.11800499, 0.18838657, 0.29271516, 0.22758003, 0.10634928,
163                              -0.04019014, 7.203311, -2.414309, 0.28750962, -0.24222863, 0.04680864,
164                              -0.12129474, 0.18059334, 0.06250379, 0.11363743, -0.2561094, -0.08132717,
165                              -0.08500769, 0.18916495, 1.3529671, -3.7919693, 1.937804, 0.6845761,
166                              0.15381853, 0.41106734, -0.28207013, 0.2195526, 0.06716935, -0.02886542,
167                              -0.22860551, 0.24788341, 0.63940096]}
168
169
170@pytest.mark.parametrize("mfcc_test_params, model_input_size, stride, min_samples, file, audio_cap_params, test_out",
171                         [(kws_test_params(), 49, 320, 16000, "yes.wav", kws_cap_params(),
172                           extract_test_data_kws),
173                          (asr_test_params(), 296, 160, 47712, "myVoiceIsMyPassportVerifyMe04.wav", asr_cap_params(),
174                           extract_test_data_asr)])
175def test_feat_extraction_full_sized_input(audio_data,
176                                          mfcc_test_params,
177                                          model_input_size,
178                                          stride,
179                                          min_samples, file, audio_cap_params,
180                                          test_out):
181    """
182    Test out values were gathered by printing the mfcc features collected during the first full inference
183    on the test wav files. Note the extract_features() function simply calls the mfcc_compute() from previous
184    test but feeds in enough samples for an inference rather than a single frame. It also computes the 1st & 2nd
185    derivative features hence the shape (13*3 = 39).
186    Specific model_input_size and stride parameters are also required as additional arguments.
187    """
188    audio_data = np.array(audio_data)
189    # Pad with zeros to ensure min_samples for inference
190    audio_data.resize(min_samples)
191    mfcc_inst = mfcc_test_params.mfcc_constructor(mfcc_test_params.algo_params)
192    preprocessor = mfcc_test_params.audio_proc_constructor(mfcc_inst, model_input_size, stride)
193    # extract_features passes the audio data to mfcc_compute frame by frame and concatenates results
194    input_tensor = preprocessor.extract_features(audio_data)
195    assert len(input_tensor) == model_input_size
196    for indx, data in test_out.items():
197        assert np.allclose(input_tensor[indx], data)
198
199
200# Expected contents of input tensors for inference on a silent wav file
201extract_features_zeros_kws = {0: [-2.05949466e+02, -4.88498131e-15, 8.15428020e-15, -5.77315973e-15,
202                                  7.03142511e-15, -1.11022302e-14, 2.18015108e-14, -1.77635684e-15,
203                                  1.06581410e-14, 2.75335310e-14],
204                              -1: [-2.05949466e+02, -4.88498131e-15, 8.15428020e-15, -5.77315973e-15,
205                                   7.03142511e-15, -1.11022302e-14, 2.18015108e-14, -1.77635684e-15,
206                                   1.06581410e-14, 2.75335310e-14]}
207
208extract_features_zeros_asr = {
209    0: [-3.46410162e+00, 2.88675135e-01, 2.88675135e-01, 2.88675135e-01,
210        2.88675135e-01, 2.88675135e-01, 2.88675135e-01, 2.88675135e-01,
211        2.88675135e-01, 2.88675135e-01, 2.88675135e-01, 2.88675135e-01,
212        2.88675135e-01, 2.79662980e+01, 1.75638694e-15, -9.41313626e-16,
213        9.66012817e-16, -1.23221521e-15, 1.75638694e-15, -1.59035349e-15,
214        2.41503204e-15, -1.64798493e-15, 4.39096735e-16, -4.95356004e-16,
215        -2.19548368e-16, -3.55668355e-15, 8.19843971e+00, -4.28340672e-02,
216        -4.28340672e-02, -4.28340672e-02, -4.28340672e-02, -4.28340672e-02,
217        -4.28340672e-02, -4.28340672e-02, -4.28340672e-02, -4.28340672e-02,
218        -4.28340672e-02, -4.28340672e-02, -4.28340672e-02],
219    - 1: [-3.46410162e+00, 2.88675135e-01, 2.88675135e-01, 2.88675135e-01,
220          2.88675135e-01, 2.88675135e-01, 2.88675135e-01, 2.88675135e-01,
221          2.88675135e-01, 2.88675135e-01, 2.88675135e-01, 2.88675135e-01,
222          2.88675135e-01, 2.79662980e+01, 1.75638694e-15, -9.41313626e-16,
223          9.66012817e-16, -1.23221521e-15, 1.75638694e-15, -1.59035349e-15,
224          2.41503204e-15, -1.64798493e-15, 4.39096735e-16, -4.95356004e-16,
225          -2.19548368e-16, -3.55668355e-15, 8.19843971e+00, -4.28340672e-02,
226          -4.28340672e-02, -4.28340672e-02, -4.28340672e-02, -4.28340672e-02,
227          -4.28340672e-02, -4.28340672e-02, -4.28340672e-02, -4.28340672e-02,
228          -4.28340672e-02, -4.28340672e-02, -4.28340672e-02]}
229
230
231@pytest.mark.parametrize("mfcc_test_params,model_input_size, stride, min_samples, test_out",
232                         [(kws_test_params(), 49, 320, 16000, extract_features_zeros_kws),
233                          (asr_test_params(), 296, 160, 47712, extract_features_zeros_asr)])
234def test_feat_extraction_full_sized_input_zeros(mfcc_test_params, model_input_size, stride, min_samples, test_out):
235    audio_data = np.zeros(min_samples).astype(np.float32)
236    mfcc_inst = mfcc_test_params.mfcc_constructor(mfcc_test_params.algo_params)
237
238    preprocessor = mfcc_test_params.audio_proc_constructor(mfcc_inst, model_input_size,
239                                                           stride)
240    input_tensor = preprocessor.extract_features(audio_data)
241    assert len(input_tensor) == model_input_size
242    for indx, data in test_out.items():
243        # Element 14 of feature extraction vector differs minutely during
244        # inference on a silent wav file compared to array of 0's
245        # Workarounds were to skip this sample or add large tolerance argument (atol=10)
246        assert np.allclose(input_tensor[indx][0:13], data[0:13])
247        assert np.allclose(input_tensor[indx][15:], data[15:])
248