xref: /aosp_15_r20/external/webrtc/modules/audio_processing/agc2/rnn_vad/spectral_features_internal.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/agc2/rnn_vad/spectral_features_internal.h"
12 
13 #include <algorithm>
14 #include <cmath>
15 #include <cstddef>
16 
17 #include "rtc_base/checks.h"
18 #include "rtc_base/numerics/safe_compare.h"
19 
20 namespace webrtc {
21 namespace rnn_vad {
22 namespace {
23 
24 // Weights for each FFT coefficient for each Opus band (Nyquist frequency
25 // excluded). The size of each band is specified in
26 // `kOpusScaleNumBins24kHz20ms`.
27 constexpr std::array<float, kFrameSize20ms24kHz / 2> kOpusBandWeights24kHz20ms =
28     {{
29         0.f,       0.25f,      0.5f,       0.75f,  // Band 0
30         0.f,       0.25f,      0.5f,       0.75f,  // Band 1
31         0.f,       0.25f,      0.5f,       0.75f,  // Band 2
32         0.f,       0.25f,      0.5f,       0.75f,  // Band 3
33         0.f,       0.25f,      0.5f,       0.75f,  // Band 4
34         0.f,       0.25f,      0.5f,       0.75f,  // Band 5
35         0.f,       0.25f,      0.5f,       0.75f,  // Band 6
36         0.f,       0.25f,      0.5f,       0.75f,  // Band 7
37         0.f,       0.125f,     0.25f,      0.375f,    0.5f,
38         0.625f,    0.75f,      0.875f,  // Band 8
39         0.f,       0.125f,     0.25f,      0.375f,    0.5f,
40         0.625f,    0.75f,      0.875f,  // Band 9
41         0.f,       0.125f,     0.25f,      0.375f,    0.5f,
42         0.625f,    0.75f,      0.875f,  // Band 10
43         0.f,       0.125f,     0.25f,      0.375f,    0.5f,
44         0.625f,    0.75f,      0.875f,  // Band 11
45         0.f,       0.0625f,    0.125f,     0.1875f,   0.25f,
46         0.3125f,   0.375f,     0.4375f,    0.5f,      0.5625f,
47         0.625f,    0.6875f,    0.75f,      0.8125f,   0.875f,
48         0.9375f,  // Band 12
49         0.f,       0.0625f,    0.125f,     0.1875f,   0.25f,
50         0.3125f,   0.375f,     0.4375f,    0.5f,      0.5625f,
51         0.625f,    0.6875f,    0.75f,      0.8125f,   0.875f,
52         0.9375f,  // Band 13
53         0.f,       0.0625f,    0.125f,     0.1875f,   0.25f,
54         0.3125f,   0.375f,     0.4375f,    0.5f,      0.5625f,
55         0.625f,    0.6875f,    0.75f,      0.8125f,   0.875f,
56         0.9375f,  // Band 14
57         0.f,       0.0416667f, 0.0833333f, 0.125f,    0.166667f,
58         0.208333f, 0.25f,      0.291667f,  0.333333f, 0.375f,
59         0.416667f, 0.458333f,  0.5f,       0.541667f, 0.583333f,
60         0.625f,    0.666667f,  0.708333f,  0.75f,     0.791667f,
61         0.833333f, 0.875f,     0.916667f,  0.958333f,  // Band 15
62         0.f,       0.0416667f, 0.0833333f, 0.125f,    0.166667f,
63         0.208333f, 0.25f,      0.291667f,  0.333333f, 0.375f,
64         0.416667f, 0.458333f,  0.5f,       0.541667f, 0.583333f,
65         0.625f,    0.666667f,  0.708333f,  0.75f,     0.791667f,
66         0.833333f, 0.875f,     0.916667f,  0.958333f,  // Band 16
67         0.f,       0.03125f,   0.0625f,    0.09375f,  0.125f,
68         0.15625f,  0.1875f,    0.21875f,   0.25f,     0.28125f,
69         0.3125f,   0.34375f,   0.375f,     0.40625f,  0.4375f,
70         0.46875f,  0.5f,       0.53125f,   0.5625f,   0.59375f,
71         0.625f,    0.65625f,   0.6875f,    0.71875f,  0.75f,
72         0.78125f,  0.8125f,    0.84375f,   0.875f,    0.90625f,
73         0.9375f,   0.96875f,  // Band 17
74         0.f,       0.0208333f, 0.0416667f, 0.0625f,   0.0833333f,
75         0.104167f, 0.125f,     0.145833f,  0.166667f, 0.1875f,
76         0.208333f, 0.229167f,  0.25f,      0.270833f, 0.291667f,
77         0.3125f,   0.333333f,  0.354167f,  0.375f,    0.395833f,
78         0.416667f, 0.4375f,    0.458333f,  0.479167f, 0.5f,
79         0.520833f, 0.541667f,  0.5625f,    0.583333f, 0.604167f,
80         0.625f,    0.645833f,  0.666667f,  0.6875f,   0.708333f,
81         0.729167f, 0.75f,      0.770833f,  0.791667f, 0.8125f,
82         0.833333f, 0.854167f,  0.875f,     0.895833f, 0.916667f,
83         0.9375f,   0.958333f,  0.979167f  // Band 18
84     }};
85 
86 }  // namespace
87 
SpectralCorrelator()88 SpectralCorrelator::SpectralCorrelator()
89     : weights_(kOpusBandWeights24kHz20ms.begin(),
90                kOpusBandWeights24kHz20ms.end()) {}
91 
92 SpectralCorrelator::~SpectralCorrelator() = default;
93 
ComputeAutoCorrelation(rtc::ArrayView<const float> x,rtc::ArrayView<float,kOpusBands24kHz> auto_corr) const94 void SpectralCorrelator::ComputeAutoCorrelation(
95     rtc::ArrayView<const float> x,
96     rtc::ArrayView<float, kOpusBands24kHz> auto_corr) const {
97   ComputeCrossCorrelation(x, x, auto_corr);
98 }
99 
ComputeCrossCorrelation(rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float,kOpusBands24kHz> cross_corr) const100 void SpectralCorrelator::ComputeCrossCorrelation(
101     rtc::ArrayView<const float> x,
102     rtc::ArrayView<const float> y,
103     rtc::ArrayView<float, kOpusBands24kHz> cross_corr) const {
104   RTC_DCHECK_EQ(x.size(), kFrameSize20ms24kHz);
105   RTC_DCHECK_EQ(x.size(), y.size());
106   RTC_DCHECK_EQ(x[1], 0.f) << "The Nyquist coefficient must be zeroed.";
107   RTC_DCHECK_EQ(y[1], 0.f) << "The Nyquist coefficient must be zeroed.";
108   constexpr auto kOpusScaleNumBins24kHz20ms = GetOpusScaleNumBins24kHz20ms();
109   int k = 0;  // Next Fourier coefficient index.
110   cross_corr[0] = 0.f;
111   for (int i = 0; i < kOpusBands24kHz - 1; ++i) {
112     cross_corr[i + 1] = 0.f;
113     for (int j = 0; j < kOpusScaleNumBins24kHz20ms[i]; ++j) {  // Band size.
114       const float v = x[2 * k] * y[2 * k] + x[2 * k + 1] * y[2 * k + 1];
115       const float tmp = weights_[k] * v;
116       cross_corr[i] += v - tmp;
117       cross_corr[i + 1] += tmp;
118       k++;
119     }
120   }
121   cross_corr[0] *= 2.f;  // The first band only gets half contribution.
122   RTC_DCHECK_EQ(k, kFrameSize20ms24kHz / 2);  // Nyquist coefficient never used.
123 }
124 
ComputeSmoothedLogMagnitudeSpectrum(rtc::ArrayView<const float> bands_energy,rtc::ArrayView<float,kNumBands> log_bands_energy)125 void ComputeSmoothedLogMagnitudeSpectrum(
126     rtc::ArrayView<const float> bands_energy,
127     rtc::ArrayView<float, kNumBands> log_bands_energy) {
128   RTC_DCHECK_LE(bands_energy.size(), kNumBands);
129   constexpr float kOneByHundred = 1e-2f;
130   constexpr float kLogOneByHundred = -2.f;
131   // Init.
132   float log_max = kLogOneByHundred;
133   float follow = kLogOneByHundred;
134   const auto smooth = [&log_max, &follow](float x) {
135     x = std::max(log_max - 7.f, std::max(follow - 1.5f, x));
136     log_max = std::max(log_max, x);
137     follow = std::max(follow - 1.5f, x);
138     return x;
139   };
140   // Smoothing over the bands for which the band energy is defined.
141   for (int i = 0; rtc::SafeLt(i, bands_energy.size()); ++i) {
142     log_bands_energy[i] = smooth(std::log10(kOneByHundred + bands_energy[i]));
143   }
144   // Smoothing over the remaining bands (zero energy).
145   for (int i = bands_energy.size(); i < kNumBands; ++i) {
146     log_bands_energy[i] = smooth(kLogOneByHundred);
147   }
148 }
149 
ComputeDctTable()150 std::array<float, kNumBands * kNumBands> ComputeDctTable() {
151   std::array<float, kNumBands * kNumBands> dct_table;
152   const double k = std::sqrt(0.5);
153   for (int i = 0; i < kNumBands; ++i) {
154     for (int j = 0; j < kNumBands; ++j)
155       dct_table[i * kNumBands + j] = std::cos((i + 0.5) * j * kPi / kNumBands);
156     dct_table[i * kNumBands] *= k;
157   }
158   return dct_table;
159 }
160 
ComputeDct(rtc::ArrayView<const float> in,rtc::ArrayView<const float,kNumBands * kNumBands> dct_table,rtc::ArrayView<float> out)161 void ComputeDct(rtc::ArrayView<const float> in,
162                 rtc::ArrayView<const float, kNumBands * kNumBands> dct_table,
163                 rtc::ArrayView<float> out) {
164   // DCT scaling factor - i.e., sqrt(2 / kNumBands).
165   constexpr float kDctScalingFactor = 0.301511345f;
166   constexpr float kDctScalingFactorError =
167       kDctScalingFactor * kDctScalingFactor -
168       2.f / static_cast<float>(kNumBands);
169   static_assert(
170       (kDctScalingFactorError >= 0.f && kDctScalingFactorError < 1e-1f) ||
171           (kDctScalingFactorError < 0.f && kDctScalingFactorError > -1e-1f),
172       "kNumBands changed and kDctScalingFactor has not been updated.");
173   RTC_DCHECK_NE(in.data(), out.data()) << "In-place DCT is not supported.";
174   RTC_DCHECK_LE(in.size(), kNumBands);
175   RTC_DCHECK_LE(1, out.size());
176   RTC_DCHECK_LE(out.size(), in.size());
177   for (int i = 0; rtc::SafeLt(i, out.size()); ++i) {
178     out[i] = 0.f;
179     for (int j = 0; rtc::SafeLt(j, in.size()); ++j) {
180       out[i] += in[j] * dct_table[j * kNumBands + i];
181     }
182     // TODO(bugs.webrtc.org/10480): Scaling factor in the DCT table.
183     out[i] *= kDctScalingFactor;
184   }
185 }
186 
187 }  // namespace rnn_vad
188 }  // namespace webrtc
189