1 /*
2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/aec3/subtractor.h"
12
13 #include <algorithm>
14 #include <memory>
15 #include <numeric>
16 #include <string>
17
18 #include "modules/audio_processing/aec3/aec_state.h"
19 #include "modules/audio_processing/aec3/render_delay_buffer.h"
20 #include "modules/audio_processing/test/echo_canceller_test_tools.h"
21 #include "modules/audio_processing/utility/cascaded_biquad_filter.h"
22 #include "rtc_base/random.h"
23 #include "rtc_base/strings/string_builder.h"
24 #include "test/gtest.h"
25
26 namespace webrtc {
27 namespace {
28
RunSubtractorTest(size_t num_render_channels,size_t num_capture_channels,int num_blocks_to_process,int delay_samples,int refined_filter_length_blocks,int coarse_filter_length_blocks,bool uncorrelated_inputs,const std::vector<int> & blocks_with_echo_path_changes)29 std::vector<float> RunSubtractorTest(
30 size_t num_render_channels,
31 size_t num_capture_channels,
32 int num_blocks_to_process,
33 int delay_samples,
34 int refined_filter_length_blocks,
35 int coarse_filter_length_blocks,
36 bool uncorrelated_inputs,
37 const std::vector<int>& blocks_with_echo_path_changes) {
38 ApmDataDumper data_dumper(42);
39 constexpr int kSampleRateHz = 48000;
40 constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
41 EchoCanceller3Config config;
42 config.filter.refined.length_blocks = refined_filter_length_blocks;
43 config.filter.coarse.length_blocks = coarse_filter_length_blocks;
44
45 Subtractor subtractor(config, num_render_channels, num_capture_channels,
46 &data_dumper, DetectOptimization());
47 absl::optional<DelayEstimate> delay_estimate;
48 Block x(kNumBands, num_render_channels);
49 Block y(/*num_bands=*/1, num_capture_channels);
50 std::array<float, kBlockSize> x_old;
51 std::vector<SubtractorOutput> output(num_capture_channels);
52 config.delay.default_delay = 1;
53 std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
54 RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels));
55 RenderSignalAnalyzer render_signal_analyzer(config);
56 Random random_generator(42U);
57 Aec3Fft fft;
58 std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels);
59 std::vector<std::array<float, kFftLengthBy2Plus1>> E2_refined(
60 num_capture_channels);
61 std::array<float, kFftLengthBy2Plus1> E2_coarse;
62 AecState aec_state(config, num_capture_channels);
63 x_old.fill(0.f);
64 for (auto& Y2_ch : Y2) {
65 Y2_ch.fill(0.f);
66 }
67 for (auto& E2_refined_ch : E2_refined) {
68 E2_refined_ch.fill(0.f);
69 }
70 E2_coarse.fill(0.f);
71
72 std::vector<std::vector<std::unique_ptr<DelayBuffer<float>>>> delay_buffer(
73 num_capture_channels);
74 for (size_t capture_ch = 0; capture_ch < num_capture_channels; ++capture_ch) {
75 delay_buffer[capture_ch].resize(num_render_channels);
76 for (size_t render_ch = 0; render_ch < num_render_channels; ++render_ch) {
77 delay_buffer[capture_ch][render_ch] =
78 std::make_unique<DelayBuffer<float>>(delay_samples);
79 }
80 }
81
82 // [B,A] = butter(2,100/8000,'high')
83 constexpr CascadedBiQuadFilter::BiQuadCoefficients
84 kHighPassFilterCoefficients = {{0.97261f, -1.94523f, 0.97261f},
85 {-1.94448f, 0.94598f}};
86 std::vector<std::unique_ptr<CascadedBiQuadFilter>> x_hp_filter(
87 num_render_channels);
88 for (size_t ch = 0; ch < num_render_channels; ++ch) {
89 x_hp_filter[ch] =
90 std::make_unique<CascadedBiQuadFilter>(kHighPassFilterCoefficients, 1);
91 }
92 std::vector<std::unique_ptr<CascadedBiQuadFilter>> y_hp_filter(
93 num_capture_channels);
94 for (size_t ch = 0; ch < num_capture_channels; ++ch) {
95 y_hp_filter[ch] =
96 std::make_unique<CascadedBiQuadFilter>(kHighPassFilterCoefficients, 1);
97 }
98
99 for (int k = 0; k < num_blocks_to_process; ++k) {
100 for (size_t render_ch = 0; render_ch < num_render_channels; ++render_ch) {
101 RandomizeSampleVector(&random_generator, x.View(/*band=*/0, render_ch));
102 }
103 if (uncorrelated_inputs) {
104 for (size_t capture_ch = 0; capture_ch < num_capture_channels;
105 ++capture_ch) {
106 RandomizeSampleVector(&random_generator,
107 y.View(/*band=*/0, capture_ch));
108 }
109 } else {
110 for (size_t capture_ch = 0; capture_ch < num_capture_channels;
111 ++capture_ch) {
112 rtc::ArrayView<float> y_view = y.View(/*band=*/0, capture_ch);
113 for (size_t render_ch = 0; render_ch < num_render_channels;
114 ++render_ch) {
115 std::array<float, kBlockSize> y_channel;
116 delay_buffer[capture_ch][render_ch]->Delay(
117 x.View(/*band=*/0, render_ch), y_channel);
118 for (size_t k = 0; k < kBlockSize; ++k) {
119 y_view[k] += y_channel[k] / num_render_channels;
120 }
121 }
122 }
123 }
124 for (size_t ch = 0; ch < num_render_channels; ++ch) {
125 x_hp_filter[ch]->Process(x.View(/*band=*/0, ch));
126 }
127 for (size_t ch = 0; ch < num_capture_channels; ++ch) {
128 y_hp_filter[ch]->Process(y.View(/*band=*/0, ch));
129 }
130
131 render_delay_buffer->Insert(x);
132 if (k == 0) {
133 render_delay_buffer->Reset();
134 }
135 render_delay_buffer->PrepareCaptureProcessing();
136 render_signal_analyzer.Update(*render_delay_buffer->GetRenderBuffer(),
137 aec_state.MinDirectPathFilterDelay());
138
139 // Handle echo path changes.
140 if (std::find(blocks_with_echo_path_changes.begin(),
141 blocks_with_echo_path_changes.end(),
142 k) != blocks_with_echo_path_changes.end()) {
143 subtractor.HandleEchoPathChange(EchoPathVariability(
144 true, EchoPathVariability::DelayAdjustment::kNewDetectedDelay,
145 false));
146 }
147 subtractor.Process(*render_delay_buffer->GetRenderBuffer(), y,
148 render_signal_analyzer, aec_state, output);
149
150 aec_state.HandleEchoPathChange(EchoPathVariability(
151 false, EchoPathVariability::DelayAdjustment::kNone, false));
152 aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(),
153 subtractor.FilterImpulseResponses(),
154 *render_delay_buffer->GetRenderBuffer(), E2_refined, Y2,
155 output);
156 }
157
158 std::vector<float> results(num_capture_channels);
159 for (size_t ch = 0; ch < num_capture_channels; ++ch) {
160 const float output_power = std::inner_product(
161 output[ch].e_refined.begin(), output[ch].e_refined.end(),
162 output[ch].e_refined.begin(), 0.f);
163 const float y_power =
164 std::inner_product(y.begin(/*band=*/0, ch), y.end(/*band=*/0, ch),
165 y.begin(/*band=*/0, ch), 0.f);
166 if (y_power == 0.f) {
167 ADD_FAILURE();
168 results[ch] = -1.f;
169 }
170 results[ch] = output_power / y_power;
171 }
172 return results;
173 }
174
ProduceDebugText(size_t num_render_channels,size_t num_capture_channels,size_t delay,int filter_length_blocks)175 std::string ProduceDebugText(size_t num_render_channels,
176 size_t num_capture_channels,
177 size_t delay,
178 int filter_length_blocks) {
179 rtc::StringBuilder ss;
180 ss << "delay: " << delay << ", ";
181 ss << "filter_length_blocks:" << filter_length_blocks << ", ";
182 ss << "num_render_channels:" << num_render_channels << ", ";
183 ss << "num_capture_channels:" << num_capture_channels;
184 return ss.Release();
185 }
186
187 } // namespace
188
#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)

// Verifies that constructing a Subtractor with a null data dumper kills the
// process. The test is only built when DCHECKs are on, presumably because a
// DCHECK is what enforces this.
TEST(SubtractorDeathTest, NullDataDumper) {
  constexpr size_t kNumRenderChannels = 1;
  constexpr size_t kNumCaptureChannels = 1;
  EXPECT_DEATH(Subtractor(EchoCanceller3Config(), kNumRenderChannels,
                          kNumCaptureChannels, /*data_dumper=*/nullptr,
                          DetectOptimization()),
               "");
}

#endif  // RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
199
200 // Verifies that the subtractor is able to converge on correlated data.
TEST(Subtractor,Convergence)201 TEST(Subtractor, Convergence) {
202 std::vector<int> blocks_with_echo_path_changes;
203 for (size_t filter_length_blocks : {12, 20, 30}) {
204 for (size_t delay_samples : {0, 64, 150, 200, 301}) {
205 SCOPED_TRACE(ProduceDebugText(1, 1, delay_samples, filter_length_blocks));
206 std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
207 1, 1, 2500, delay_samples, filter_length_blocks, filter_length_blocks,
208 false, blocks_with_echo_path_changes);
209
210 for (float echo_to_nearend_power : echo_to_nearend_powers) {
211 EXPECT_GT(0.1f, echo_to_nearend_power);
212 }
213 }
214 }
215 }
216
217 // Verifies that the subtractor is able to handle the case when the refined
218 // filter is longer than the coarse filter.
TEST(Subtractor,RefinedFilterLongerThanCoarseFilter)219 TEST(Subtractor, RefinedFilterLongerThanCoarseFilter) {
220 std::vector<int> blocks_with_echo_path_changes;
221 std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
222 1, 1, 400, 64, 20, 15, false, blocks_with_echo_path_changes);
223 for (float echo_to_nearend_power : echo_to_nearend_powers) {
224 EXPECT_GT(0.5f, echo_to_nearend_power);
225 }
226 }
227
228 // Verifies that the subtractor is able to handle the case when the coarse
229 // filter is longer than the refined filter.
TEST(Subtractor,CoarseFilterLongerThanRefinedFilter)230 TEST(Subtractor, CoarseFilterLongerThanRefinedFilter) {
231 std::vector<int> blocks_with_echo_path_changes;
232 std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
233 1, 1, 400, 64, 15, 20, false, blocks_with_echo_path_changes);
234 for (float echo_to_nearend_power : echo_to_nearend_powers) {
235 EXPECT_GT(0.5f, echo_to_nearend_power);
236 }
237 }
238
239 // Verifies that the subtractor does not converge on uncorrelated signals.
TEST(Subtractor,NonConvergenceOnUncorrelatedSignals)240 TEST(Subtractor, NonConvergenceOnUncorrelatedSignals) {
241 std::vector<int> blocks_with_echo_path_changes;
242 for (size_t filter_length_blocks : {12, 20, 30}) {
243 for (size_t delay_samples : {0, 64, 150, 200, 301}) {
244 SCOPED_TRACE(ProduceDebugText(1, 1, delay_samples, filter_length_blocks));
245
246 std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
247 1, 1, 3000, delay_samples, filter_length_blocks, filter_length_blocks,
248 true, blocks_with_echo_path_changes);
249 for (float echo_to_nearend_power : echo_to_nearend_powers) {
250 EXPECT_NEAR(1.f, echo_to_nearend_power, 0.1);
251 }
252 }
253 }
254 }
255
256 class SubtractorMultiChannelUpToEightRender
257 : public ::testing::Test,
258 public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
259
// Instantiates the suite for (num_render_channels, num_capture_channels)
// combinations. Builds without NDEBUG run a reduced set of combinations,
// presumably to keep debug-build test time bounded (confirm).
#if defined(NDEBUG)
INSTANTIATE_TEST_SUITE_P(NonDebugMultiChannel,
                         SubtractorMultiChannelUpToEightRender,
                         ::testing::Combine(::testing::Values(1, 2, 8),
                                            ::testing::Values(1, 2, 4)));
#else
INSTANTIATE_TEST_SUITE_P(DebugMultiChannel,
                         SubtractorMultiChannelUpToEightRender,
                         ::testing::Combine(::testing::Values(1, 2),
                                            ::testing::Values(1, 2)));
#endif
271
272 // Verifies that the subtractor is able to converge on correlated data.
TEST_P(SubtractorMultiChannelUpToEightRender,Convergence)273 TEST_P(SubtractorMultiChannelUpToEightRender, Convergence) {
274 const size_t num_render_channels = std::get<0>(GetParam());
275 const size_t num_capture_channels = std::get<1>(GetParam());
276
277 std::vector<int> blocks_with_echo_path_changes;
278 size_t num_blocks_to_process = 2500 * num_render_channels;
279 std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
280 num_render_channels, num_capture_channels, num_blocks_to_process, 64, 20,
281 20, false, blocks_with_echo_path_changes);
282
283 for (float echo_to_nearend_power : echo_to_nearend_powers) {
284 EXPECT_GT(0.1f, echo_to_nearend_power);
285 }
286 }
287
288 class SubtractorMultiChannelUpToFourRender
289 : public ::testing::Test,
290 public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
291
// Instantiates the suite for (num_render_channels, num_capture_channels)
// combinations. Builds without NDEBUG run a reduced set of combinations,
// presumably to keep debug-build test time bounded (confirm).
#if defined(NDEBUG)
INSTANTIATE_TEST_SUITE_P(NonDebugMultiChannel,
                         SubtractorMultiChannelUpToFourRender,
                         ::testing::Combine(::testing::Values(1, 2, 4),
                                            ::testing::Values(1, 2, 4)));
#else
INSTANTIATE_TEST_SUITE_P(DebugMultiChannel,
                         SubtractorMultiChannelUpToFourRender,
                         ::testing::Combine(::testing::Values(1, 2),
                                            ::testing::Values(1, 2)));
#endif
303
304 // Verifies that the subtractor does not converge on uncorrelated signals.
TEST_P(SubtractorMultiChannelUpToFourRender,NonConvergenceOnUncorrelatedSignals)305 TEST_P(SubtractorMultiChannelUpToFourRender,
306 NonConvergenceOnUncorrelatedSignals) {
307 const size_t num_render_channels = std::get<0>(GetParam());
308 const size_t num_capture_channels = std::get<1>(GetParam());
309
310 std::vector<int> blocks_with_echo_path_changes;
311 size_t num_blocks_to_process = 5000 * num_render_channels;
312 std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
313 num_render_channels, num_capture_channels, num_blocks_to_process, 64, 20,
314 20, true, blocks_with_echo_path_changes);
315 for (float echo_to_nearend_power : echo_to_nearend_powers) {
316 EXPECT_LT(.8f, echo_to_nearend_power);
317 EXPECT_NEAR(1.f, echo_to_nearend_power, 0.25f);
318 }
319 }
320 } // namespace webrtc
321