1 /*
2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/agc2/rnn_vad/spectral_features_internal.h"
12
13 #include <algorithm>
14 #include <cmath>
15 #include <cstddef>
16
17 #include "rtc_base/checks.h"
18 #include "rtc_base/numerics/safe_compare.h"
19
20 namespace webrtc {
21 namespace rnn_vad {
22 namespace {
23
24 // Weights for each FFT coefficient for each Opus band (Nyquist frequency
25 // excluded). The size of each band is specified in
26 // `kOpusScaleNumBins24kHz20ms`.
27 constexpr std::array<float, kFrameSize20ms24kHz / 2> kOpusBandWeights24kHz20ms =
28 {{
29 0.f, 0.25f, 0.5f, 0.75f, // Band 0
30 0.f, 0.25f, 0.5f, 0.75f, // Band 1
31 0.f, 0.25f, 0.5f, 0.75f, // Band 2
32 0.f, 0.25f, 0.5f, 0.75f, // Band 3
33 0.f, 0.25f, 0.5f, 0.75f, // Band 4
34 0.f, 0.25f, 0.5f, 0.75f, // Band 5
35 0.f, 0.25f, 0.5f, 0.75f, // Band 6
36 0.f, 0.25f, 0.5f, 0.75f, // Band 7
37 0.f, 0.125f, 0.25f, 0.375f, 0.5f,
38 0.625f, 0.75f, 0.875f, // Band 8
39 0.f, 0.125f, 0.25f, 0.375f, 0.5f,
40 0.625f, 0.75f, 0.875f, // Band 9
41 0.f, 0.125f, 0.25f, 0.375f, 0.5f,
42 0.625f, 0.75f, 0.875f, // Band 10
43 0.f, 0.125f, 0.25f, 0.375f, 0.5f,
44 0.625f, 0.75f, 0.875f, // Band 11
45 0.f, 0.0625f, 0.125f, 0.1875f, 0.25f,
46 0.3125f, 0.375f, 0.4375f, 0.5f, 0.5625f,
47 0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f,
48 0.9375f, // Band 12
49 0.f, 0.0625f, 0.125f, 0.1875f, 0.25f,
50 0.3125f, 0.375f, 0.4375f, 0.5f, 0.5625f,
51 0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f,
52 0.9375f, // Band 13
53 0.f, 0.0625f, 0.125f, 0.1875f, 0.25f,
54 0.3125f, 0.375f, 0.4375f, 0.5f, 0.5625f,
55 0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f,
56 0.9375f, // Band 14
57 0.f, 0.0416667f, 0.0833333f, 0.125f, 0.166667f,
58 0.208333f, 0.25f, 0.291667f, 0.333333f, 0.375f,
59 0.416667f, 0.458333f, 0.5f, 0.541667f, 0.583333f,
60 0.625f, 0.666667f, 0.708333f, 0.75f, 0.791667f,
61 0.833333f, 0.875f, 0.916667f, 0.958333f, // Band 15
62 0.f, 0.0416667f, 0.0833333f, 0.125f, 0.166667f,
63 0.208333f, 0.25f, 0.291667f, 0.333333f, 0.375f,
64 0.416667f, 0.458333f, 0.5f, 0.541667f, 0.583333f,
65 0.625f, 0.666667f, 0.708333f, 0.75f, 0.791667f,
66 0.833333f, 0.875f, 0.916667f, 0.958333f, // Band 16
67 0.f, 0.03125f, 0.0625f, 0.09375f, 0.125f,
68 0.15625f, 0.1875f, 0.21875f, 0.25f, 0.28125f,
69 0.3125f, 0.34375f, 0.375f, 0.40625f, 0.4375f,
70 0.46875f, 0.5f, 0.53125f, 0.5625f, 0.59375f,
71 0.625f, 0.65625f, 0.6875f, 0.71875f, 0.75f,
72 0.78125f, 0.8125f, 0.84375f, 0.875f, 0.90625f,
73 0.9375f, 0.96875f, // Band 17
74 0.f, 0.0208333f, 0.0416667f, 0.0625f, 0.0833333f,
75 0.104167f, 0.125f, 0.145833f, 0.166667f, 0.1875f,
76 0.208333f, 0.229167f, 0.25f, 0.270833f, 0.291667f,
77 0.3125f, 0.333333f, 0.354167f, 0.375f, 0.395833f,
78 0.416667f, 0.4375f, 0.458333f, 0.479167f, 0.5f,
79 0.520833f, 0.541667f, 0.5625f, 0.583333f, 0.604167f,
80 0.625f, 0.645833f, 0.666667f, 0.6875f, 0.708333f,
81 0.729167f, 0.75f, 0.770833f, 0.791667f, 0.8125f,
82 0.833333f, 0.854167f, 0.875f, 0.895833f, 0.916667f,
83 0.9375f, 0.958333f, 0.979167f // Band 18
84 }};
85
86 } // namespace
87
SpectralCorrelator()88 SpectralCorrelator::SpectralCorrelator()
89 : weights_(kOpusBandWeights24kHz20ms.begin(),
90 kOpusBandWeights24kHz20ms.end()) {}
91
92 SpectralCorrelator::~SpectralCorrelator() = default;
93
ComputeAutoCorrelation(rtc::ArrayView<const float> x,rtc::ArrayView<float,kOpusBands24kHz> auto_corr) const94 void SpectralCorrelator::ComputeAutoCorrelation(
95 rtc::ArrayView<const float> x,
96 rtc::ArrayView<float, kOpusBands24kHz> auto_corr) const {
97 ComputeCrossCorrelation(x, x, auto_corr);
98 }
99
ComputeCrossCorrelation(rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float,kOpusBands24kHz> cross_corr) const100 void SpectralCorrelator::ComputeCrossCorrelation(
101 rtc::ArrayView<const float> x,
102 rtc::ArrayView<const float> y,
103 rtc::ArrayView<float, kOpusBands24kHz> cross_corr) const {
104 RTC_DCHECK_EQ(x.size(), kFrameSize20ms24kHz);
105 RTC_DCHECK_EQ(x.size(), y.size());
106 RTC_DCHECK_EQ(x[1], 0.f) << "The Nyquist coefficient must be zeroed.";
107 RTC_DCHECK_EQ(y[1], 0.f) << "The Nyquist coefficient must be zeroed.";
108 constexpr auto kOpusScaleNumBins24kHz20ms = GetOpusScaleNumBins24kHz20ms();
109 int k = 0; // Next Fourier coefficient index.
110 cross_corr[0] = 0.f;
111 for (int i = 0; i < kOpusBands24kHz - 1; ++i) {
112 cross_corr[i + 1] = 0.f;
113 for (int j = 0; j < kOpusScaleNumBins24kHz20ms[i]; ++j) { // Band size.
114 const float v = x[2 * k] * y[2 * k] + x[2 * k + 1] * y[2 * k + 1];
115 const float tmp = weights_[k] * v;
116 cross_corr[i] += v - tmp;
117 cross_corr[i + 1] += tmp;
118 k++;
119 }
120 }
121 cross_corr[0] *= 2.f; // The first band only gets half contribution.
122 RTC_DCHECK_EQ(k, kFrameSize20ms24kHz / 2); // Nyquist coefficient never used.
123 }
124
ComputeSmoothedLogMagnitudeSpectrum(rtc::ArrayView<const float> bands_energy,rtc::ArrayView<float,kNumBands> log_bands_energy)125 void ComputeSmoothedLogMagnitudeSpectrum(
126 rtc::ArrayView<const float> bands_energy,
127 rtc::ArrayView<float, kNumBands> log_bands_energy) {
128 RTC_DCHECK_LE(bands_energy.size(), kNumBands);
129 constexpr float kOneByHundred = 1e-2f;
130 constexpr float kLogOneByHundred = -2.f;
131 // Init.
132 float log_max = kLogOneByHundred;
133 float follow = kLogOneByHundred;
134 const auto smooth = [&log_max, &follow](float x) {
135 x = std::max(log_max - 7.f, std::max(follow - 1.5f, x));
136 log_max = std::max(log_max, x);
137 follow = std::max(follow - 1.5f, x);
138 return x;
139 };
140 // Smoothing over the bands for which the band energy is defined.
141 for (int i = 0; rtc::SafeLt(i, bands_energy.size()); ++i) {
142 log_bands_energy[i] = smooth(std::log10(kOneByHundred + bands_energy[i]));
143 }
144 // Smoothing over the remaining bands (zero energy).
145 for (int i = bands_energy.size(); i < kNumBands; ++i) {
146 log_bands_energy[i] = smooth(kLogOneByHundred);
147 }
148 }
149
ComputeDctTable()150 std::array<float, kNumBands * kNumBands> ComputeDctTable() {
151 std::array<float, kNumBands * kNumBands> dct_table;
152 const double k = std::sqrt(0.5);
153 for (int i = 0; i < kNumBands; ++i) {
154 for (int j = 0; j < kNumBands; ++j)
155 dct_table[i * kNumBands + j] = std::cos((i + 0.5) * j * kPi / kNumBands);
156 dct_table[i * kNumBands] *= k;
157 }
158 return dct_table;
159 }
160
ComputeDct(rtc::ArrayView<const float> in,rtc::ArrayView<const float,kNumBands * kNumBands> dct_table,rtc::ArrayView<float> out)161 void ComputeDct(rtc::ArrayView<const float> in,
162 rtc::ArrayView<const float, kNumBands * kNumBands> dct_table,
163 rtc::ArrayView<float> out) {
164 // DCT scaling factor - i.e., sqrt(2 / kNumBands).
165 constexpr float kDctScalingFactor = 0.301511345f;
166 constexpr float kDctScalingFactorError =
167 kDctScalingFactor * kDctScalingFactor -
168 2.f / static_cast<float>(kNumBands);
169 static_assert(
170 (kDctScalingFactorError >= 0.f && kDctScalingFactorError < 1e-1f) ||
171 (kDctScalingFactorError < 0.f && kDctScalingFactorError > -1e-1f),
172 "kNumBands changed and kDctScalingFactor has not been updated.");
173 RTC_DCHECK_NE(in.data(), out.data()) << "In-place DCT is not supported.";
174 RTC_DCHECK_LE(in.size(), kNumBands);
175 RTC_DCHECK_LE(1, out.size());
176 RTC_DCHECK_LE(out.size(), in.size());
177 for (int i = 0; rtc::SafeLt(i, out.size()); ++i) {
178 out[i] = 0.f;
179 for (int j = 0; rtc::SafeLt(j, in.size()); ++j) {
180 out[i] += in[j] * dct_table[j * kNumBands + i];
181 }
182 // TODO(bugs.webrtc.org/10480): Scaling factor in the DCT table.
183 out[i] *= kDctScalingFactor;
184 }
185 }
186
187 } // namespace rnn_vad
188 } // namespace webrtc
189