xref: /aosp_15_r20/external/webrtc/modules/audio_processing/agc2/rnn_vad/spectral_features.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/agc2/rnn_vad/spectral_features.h"
12 
13 #include <algorithm>
14 #include <cmath>
15 #include <limits>
16 #include <numeric>
17 
18 #include "rtc_base/checks.h"
19 #include "rtc_base/numerics/safe_compare.h"
20 
21 namespace webrtc {
22 namespace rnn_vad {
23 namespace {
24 
// Reference frames whose total band energy is below this threshold are
// reported as silence by `CheckSilenceComputeFeatures()`.
constexpr float kSilenceThreshold = 0.04f;
26 
27 // Computes the new cepstral difference stats and pushes them into the passed
28 // symmetric matrix buffer.
UpdateCepstralDifferenceStats(rtc::ArrayView<const float,kNumBands> new_cepstral_coeffs,const RingBuffer<float,kNumBands,kCepstralCoeffsHistorySize> & ring_buf,SymmetricMatrixBuffer<float,kCepstralCoeffsHistorySize> * sym_matrix_buf)29 void UpdateCepstralDifferenceStats(
30     rtc::ArrayView<const float, kNumBands> new_cepstral_coeffs,
31     const RingBuffer<float, kNumBands, kCepstralCoeffsHistorySize>& ring_buf,
32     SymmetricMatrixBuffer<float, kCepstralCoeffsHistorySize>* sym_matrix_buf) {
33   RTC_DCHECK(sym_matrix_buf);
34   // Compute the new cepstral distance stats.
35   std::array<float, kCepstralCoeffsHistorySize - 1> distances;
36   for (int i = 0; i < kCepstralCoeffsHistorySize - 1; ++i) {
37     const int delay = i + 1;
38     auto old_cepstral_coeffs = ring_buf.GetArrayView(delay);
39     distances[i] = 0.f;
40     for (int k = 0; k < kNumBands; ++k) {
41       const float c = new_cepstral_coeffs[k] - old_cepstral_coeffs[k];
42       distances[i] += c * c;
43     }
44   }
45   // Push the new spectral distance stats into the symmetric matrix buffer.
46   sym_matrix_buf->Push(distances);
47 }
48 
49 // Computes the first half of the Vorbis window.
ComputeScaledHalfVorbisWindow(float scaling=1.f)50 std::array<float, kFrameSize20ms24kHz / 2> ComputeScaledHalfVorbisWindow(
51     float scaling = 1.f) {
52   constexpr int kHalfSize = kFrameSize20ms24kHz / 2;
53   std::array<float, kHalfSize> half_window{};
54   for (int i = 0; i < kHalfSize; ++i) {
55     half_window[i] =
56         scaling *
57         std::sin(0.5 * kPi * std::sin(0.5 * kPi * (i + 0.5) / kHalfSize) *
58                  std::sin(0.5 * kPi * (i + 0.5) / kHalfSize));
59   }
60   return half_window;
61 }
62 
63 // Computes the forward FFT on a 20 ms frame to which a given window function is
64 // applied. The Fourier coefficient corresponding to the Nyquist frequency is
65 // set to zero (it is never used and this allows to simplify the code).
ComputeWindowedForwardFft(rtc::ArrayView<const float,kFrameSize20ms24kHz> frame,const std::array<float,kFrameSize20ms24kHz/2> & half_window,Pffft::FloatBuffer * fft_input_buffer,Pffft::FloatBuffer * fft_output_buffer,Pffft * fft)66 void ComputeWindowedForwardFft(
67     rtc::ArrayView<const float, kFrameSize20ms24kHz> frame,
68     const std::array<float, kFrameSize20ms24kHz / 2>& half_window,
69     Pffft::FloatBuffer* fft_input_buffer,
70     Pffft::FloatBuffer* fft_output_buffer,
71     Pffft* fft) {
72   RTC_DCHECK_EQ(frame.size(), 2 * half_window.size());
73   // Apply windowing.
74   auto in = fft_input_buffer->GetView();
75   for (int i = 0, j = kFrameSize20ms24kHz - 1;
76        rtc::SafeLt(i, half_window.size()); ++i, --j) {
77     in[i] = frame[i] * half_window[i];
78     in[j] = frame[j] * half_window[i];
79   }
80   fft->ForwardTransform(*fft_input_buffer, fft_output_buffer, /*ordered=*/true);
81   // Set the Nyquist frequency coefficient to zero.
82   auto out = fft_output_buffer->GetView();
83   out[1] = 0.f;
84 }
85 
86 }  // namespace
87 
// Precomputes the half Vorbis window scaled by 1 / `kFrameSize20ms24kHz`,
// creates a real FFT of size `kFrameSize20ms24kHz` together with the buffers
// used for the FFT input and for the FFTs of the reference and lagged frames,
// and precomputes the DCT table.
SpectralFeaturesExtractor::SpectralFeaturesExtractor()
    : half_window_(ComputeScaledHalfVorbisWindow(
          1.f / static_cast<float>(kFrameSize20ms24kHz))),
      fft_(kFrameSize20ms24kHz, Pffft::FftType::kReal),
      fft_buffer_(fft_.CreateBuffer()),
      reference_frame_fft_(fft_.CreateBuffer()),
      lagged_frame_fft_(fft_.CreateBuffer()),
      dct_table_(ComputeDctTable()) {}
96 
// Defaulted destructor, defined out of line in this translation unit.
SpectralFeaturesExtractor::~SpectralFeaturesExtractor() = default;
98 
// Resets the cepstral coefficients ring buffer and the cepstral differences
// buffer.
void SpectralFeaturesExtractor::Reset() {
  cepstral_coeffs_ring_buf_.Reset();
  cepstral_diffs_buf_.Reset();
}
103 
CheckSilenceComputeFeatures(rtc::ArrayView<const float,kFrameSize20ms24kHz> reference_frame,rtc::ArrayView<const float,kFrameSize20ms24kHz> lagged_frame,rtc::ArrayView<float,kNumBands-kNumLowerBands> higher_bands_cepstrum,rtc::ArrayView<float,kNumLowerBands> average,rtc::ArrayView<float,kNumLowerBands> first_derivative,rtc::ArrayView<float,kNumLowerBands> second_derivative,rtc::ArrayView<float,kNumLowerBands> bands_cross_corr,float * variability)104 bool SpectralFeaturesExtractor::CheckSilenceComputeFeatures(
105     rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame,
106     rtc::ArrayView<const float, kFrameSize20ms24kHz> lagged_frame,
107     rtc::ArrayView<float, kNumBands - kNumLowerBands> higher_bands_cepstrum,
108     rtc::ArrayView<float, kNumLowerBands> average,
109     rtc::ArrayView<float, kNumLowerBands> first_derivative,
110     rtc::ArrayView<float, kNumLowerBands> second_derivative,
111     rtc::ArrayView<float, kNumLowerBands> bands_cross_corr,
112     float* variability) {
113   // Compute the Opus band energies for the reference frame.
114   ComputeWindowedForwardFft(reference_frame, half_window_, fft_buffer_.get(),
115                             reference_frame_fft_.get(), &fft_);
116   spectral_correlator_.ComputeAutoCorrelation(
117       reference_frame_fft_->GetConstView(), reference_frame_bands_energy_);
118   // Check if the reference frame has silence.
119   const float tot_energy =
120       std::accumulate(reference_frame_bands_energy_.begin(),
121                       reference_frame_bands_energy_.end(), 0.f);
122   if (tot_energy < kSilenceThreshold) {
123     return true;
124   }
125   // Compute the Opus band energies for the lagged frame.
126   ComputeWindowedForwardFft(lagged_frame, half_window_, fft_buffer_.get(),
127                             lagged_frame_fft_.get(), &fft_);
128   spectral_correlator_.ComputeAutoCorrelation(lagged_frame_fft_->GetConstView(),
129                                               lagged_frame_bands_energy_);
130   // Log of the band energies for the reference frame.
131   std::array<float, kNumBands> log_bands_energy;
132   ComputeSmoothedLogMagnitudeSpectrum(reference_frame_bands_energy_,
133                                       log_bands_energy);
134   // Reference frame cepstrum.
135   std::array<float, kNumBands> cepstrum;
136   ComputeDct(log_bands_energy, dct_table_, cepstrum);
137   // Ad-hoc correction terms for the first two cepstral coefficients.
138   cepstrum[0] -= 12.f;
139   cepstrum[1] -= 4.f;
140   // Update the ring buffer and the cepstral difference stats.
141   cepstral_coeffs_ring_buf_.Push(cepstrum);
142   UpdateCepstralDifferenceStats(cepstrum, cepstral_coeffs_ring_buf_,
143                                 &cepstral_diffs_buf_);
144   // Write the higher bands cepstral coefficients.
145   RTC_DCHECK_EQ(cepstrum.size() - kNumLowerBands, higher_bands_cepstrum.size());
146   std::copy(cepstrum.begin() + kNumLowerBands, cepstrum.end(),
147             higher_bands_cepstrum.begin());
148   // Compute and write remaining features.
149   ComputeAvgAndDerivatives(average, first_derivative, second_derivative);
150   ComputeNormalizedCepstralCorrelation(bands_cross_corr);
151   RTC_DCHECK(variability);
152   *variability = ComputeVariability();
153   return false;
154 }
155 
ComputeAvgAndDerivatives(rtc::ArrayView<float,kNumLowerBands> average,rtc::ArrayView<float,kNumLowerBands> first_derivative,rtc::ArrayView<float,kNumLowerBands> second_derivative) const156 void SpectralFeaturesExtractor::ComputeAvgAndDerivatives(
157     rtc::ArrayView<float, kNumLowerBands> average,
158     rtc::ArrayView<float, kNumLowerBands> first_derivative,
159     rtc::ArrayView<float, kNumLowerBands> second_derivative) const {
160   auto curr = cepstral_coeffs_ring_buf_.GetArrayView(0);
161   auto prev1 = cepstral_coeffs_ring_buf_.GetArrayView(1);
162   auto prev2 = cepstral_coeffs_ring_buf_.GetArrayView(2);
163   RTC_DCHECK_EQ(average.size(), first_derivative.size());
164   RTC_DCHECK_EQ(first_derivative.size(), second_derivative.size());
165   RTC_DCHECK_LE(average.size(), curr.size());
166   for (int i = 0; rtc::SafeLt(i, average.size()); ++i) {
167     // Average, kernel: [1, 1, 1].
168     average[i] = curr[i] + prev1[i] + prev2[i];
169     // First derivative, kernel: [1, 0, - 1].
170     first_derivative[i] = curr[i] - prev2[i];
171     // Second derivative, Laplacian kernel: [1, -2, 1].
172     second_derivative[i] = curr[i] - 2 * prev1[i] + prev2[i];
173   }
174 }
175 
ComputeNormalizedCepstralCorrelation(rtc::ArrayView<float,kNumLowerBands> bands_cross_corr)176 void SpectralFeaturesExtractor::ComputeNormalizedCepstralCorrelation(
177     rtc::ArrayView<float, kNumLowerBands> bands_cross_corr) {
178   spectral_correlator_.ComputeCrossCorrelation(
179       reference_frame_fft_->GetConstView(), lagged_frame_fft_->GetConstView(),
180       bands_cross_corr_);
181   // Normalize.
182   for (int i = 0; rtc::SafeLt(i, bands_cross_corr_.size()); ++i) {
183     bands_cross_corr_[i] =
184         bands_cross_corr_[i] /
185         std::sqrt(0.001f + reference_frame_bands_energy_[i] *
186                                lagged_frame_bands_energy_[i]);
187   }
188   // Cepstrum.
189   ComputeDct(bands_cross_corr_, dct_table_, bands_cross_corr);
190   // Ad-hoc correction terms for the first two cepstral coefficients.
191   bands_cross_corr[0] -= 1.3f;
192   bands_cross_corr[1] -= 0.9f;
193 }
194 
ComputeVariability() const195 float SpectralFeaturesExtractor::ComputeVariability() const {
196   // Compute cepstral variability score.
197   float variability = 0.f;
198   for (int delay1 = 0; delay1 < kCepstralCoeffsHistorySize; ++delay1) {
199     float min_dist = std::numeric_limits<float>::max();
200     for (int delay2 = 0; delay2 < kCepstralCoeffsHistorySize; ++delay2) {
201       if (delay1 == delay2)  // The distance would be 0.
202         continue;
203       min_dist =
204           std::min(min_dist, cepstral_diffs_buf_.GetValue(delay1, delay2));
205     }
206     variability += min_dist;
207   }
208   // Normalize (based on training set stats).
209   // TODO(bugs.webrtc.org/10480): Isolate normalization from feature extraction.
210   return variability / kCepstralCoeffsHistorySize - 2.1f;
211 }
212 
213 }  // namespace rnn_vad
214 }  // namespace webrtc
215