// xref: /aosp_15_r20/external/webrtc/modules/audio_processing/aec3/subtractor_unittest.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
/*
 *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
10 
11 #include "modules/audio_processing/aec3/subtractor.h"
12 
13 #include <algorithm>
14 #include <memory>
15 #include <numeric>
16 #include <string>
17 
18 #include "modules/audio_processing/aec3/aec_state.h"
19 #include "modules/audio_processing/aec3/render_delay_buffer.h"
20 #include "modules/audio_processing/test/echo_canceller_test_tools.h"
21 #include "modules/audio_processing/utility/cascaded_biquad_filter.h"
22 #include "rtc_base/random.h"
23 #include "rtc_base/strings/string_builder.h"
24 #include "test/gtest.h"
25 
26 namespace webrtc {
27 namespace {
28 
RunSubtractorTest(size_t num_render_channels,size_t num_capture_channels,int num_blocks_to_process,int delay_samples,int refined_filter_length_blocks,int coarse_filter_length_blocks,bool uncorrelated_inputs,const std::vector<int> & blocks_with_echo_path_changes)29 std::vector<float> RunSubtractorTest(
30     size_t num_render_channels,
31     size_t num_capture_channels,
32     int num_blocks_to_process,
33     int delay_samples,
34     int refined_filter_length_blocks,
35     int coarse_filter_length_blocks,
36     bool uncorrelated_inputs,
37     const std::vector<int>& blocks_with_echo_path_changes) {
38   ApmDataDumper data_dumper(42);
39   constexpr int kSampleRateHz = 48000;
40   constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
41   EchoCanceller3Config config;
42   config.filter.refined.length_blocks = refined_filter_length_blocks;
43   config.filter.coarse.length_blocks = coarse_filter_length_blocks;
44 
45   Subtractor subtractor(config, num_render_channels, num_capture_channels,
46                         &data_dumper, DetectOptimization());
47   absl::optional<DelayEstimate> delay_estimate;
48   Block x(kNumBands, num_render_channels);
49   Block y(/*num_bands=*/1, num_capture_channels);
50   std::array<float, kBlockSize> x_old;
51   std::vector<SubtractorOutput> output(num_capture_channels);
52   config.delay.default_delay = 1;
53   std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
54       RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels));
55   RenderSignalAnalyzer render_signal_analyzer(config);
56   Random random_generator(42U);
57   Aec3Fft fft;
58   std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels);
59   std::vector<std::array<float, kFftLengthBy2Plus1>> E2_refined(
60       num_capture_channels);
61   std::array<float, kFftLengthBy2Plus1> E2_coarse;
62   AecState aec_state(config, num_capture_channels);
63   x_old.fill(0.f);
64   for (auto& Y2_ch : Y2) {
65     Y2_ch.fill(0.f);
66   }
67   for (auto& E2_refined_ch : E2_refined) {
68     E2_refined_ch.fill(0.f);
69   }
70   E2_coarse.fill(0.f);
71 
72   std::vector<std::vector<std::unique_ptr<DelayBuffer<float>>>> delay_buffer(
73       num_capture_channels);
74   for (size_t capture_ch = 0; capture_ch < num_capture_channels; ++capture_ch) {
75     delay_buffer[capture_ch].resize(num_render_channels);
76     for (size_t render_ch = 0; render_ch < num_render_channels; ++render_ch) {
77       delay_buffer[capture_ch][render_ch] =
78           std::make_unique<DelayBuffer<float>>(delay_samples);
79     }
80   }
81 
82   // [B,A] = butter(2,100/8000,'high')
83   constexpr CascadedBiQuadFilter::BiQuadCoefficients
84       kHighPassFilterCoefficients = {{0.97261f, -1.94523f, 0.97261f},
85                                      {-1.94448f, 0.94598f}};
86   std::vector<std::unique_ptr<CascadedBiQuadFilter>> x_hp_filter(
87       num_render_channels);
88   for (size_t ch = 0; ch < num_render_channels; ++ch) {
89     x_hp_filter[ch] =
90         std::make_unique<CascadedBiQuadFilter>(kHighPassFilterCoefficients, 1);
91   }
92   std::vector<std::unique_ptr<CascadedBiQuadFilter>> y_hp_filter(
93       num_capture_channels);
94   for (size_t ch = 0; ch < num_capture_channels; ++ch) {
95     y_hp_filter[ch] =
96         std::make_unique<CascadedBiQuadFilter>(kHighPassFilterCoefficients, 1);
97   }
98 
99   for (int k = 0; k < num_blocks_to_process; ++k) {
100     for (size_t render_ch = 0; render_ch < num_render_channels; ++render_ch) {
101       RandomizeSampleVector(&random_generator, x.View(/*band=*/0, render_ch));
102     }
103     if (uncorrelated_inputs) {
104       for (size_t capture_ch = 0; capture_ch < num_capture_channels;
105            ++capture_ch) {
106         RandomizeSampleVector(&random_generator,
107                               y.View(/*band=*/0, capture_ch));
108       }
109     } else {
110       for (size_t capture_ch = 0; capture_ch < num_capture_channels;
111            ++capture_ch) {
112         rtc::ArrayView<float> y_view = y.View(/*band=*/0, capture_ch);
113         for (size_t render_ch = 0; render_ch < num_render_channels;
114              ++render_ch) {
115           std::array<float, kBlockSize> y_channel;
116           delay_buffer[capture_ch][render_ch]->Delay(
117               x.View(/*band=*/0, render_ch), y_channel);
118           for (size_t k = 0; k < kBlockSize; ++k) {
119             y_view[k] += y_channel[k] / num_render_channels;
120           }
121         }
122       }
123     }
124     for (size_t ch = 0; ch < num_render_channels; ++ch) {
125       x_hp_filter[ch]->Process(x.View(/*band=*/0, ch));
126     }
127     for (size_t ch = 0; ch < num_capture_channels; ++ch) {
128       y_hp_filter[ch]->Process(y.View(/*band=*/0, ch));
129     }
130 
131     render_delay_buffer->Insert(x);
132     if (k == 0) {
133       render_delay_buffer->Reset();
134     }
135     render_delay_buffer->PrepareCaptureProcessing();
136     render_signal_analyzer.Update(*render_delay_buffer->GetRenderBuffer(),
137                                   aec_state.MinDirectPathFilterDelay());
138 
139     // Handle echo path changes.
140     if (std::find(blocks_with_echo_path_changes.begin(),
141                   blocks_with_echo_path_changes.end(),
142                   k) != blocks_with_echo_path_changes.end()) {
143       subtractor.HandleEchoPathChange(EchoPathVariability(
144           true, EchoPathVariability::DelayAdjustment::kNewDetectedDelay,
145           false));
146     }
147     subtractor.Process(*render_delay_buffer->GetRenderBuffer(), y,
148                        render_signal_analyzer, aec_state, output);
149 
150     aec_state.HandleEchoPathChange(EchoPathVariability(
151         false, EchoPathVariability::DelayAdjustment::kNone, false));
152     aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(),
153                      subtractor.FilterImpulseResponses(),
154                      *render_delay_buffer->GetRenderBuffer(), E2_refined, Y2,
155                      output);
156   }
157 
158   std::vector<float> results(num_capture_channels);
159   for (size_t ch = 0; ch < num_capture_channels; ++ch) {
160     const float output_power = std::inner_product(
161         output[ch].e_refined.begin(), output[ch].e_refined.end(),
162         output[ch].e_refined.begin(), 0.f);
163     const float y_power =
164         std::inner_product(y.begin(/*band=*/0, ch), y.end(/*band=*/0, ch),
165                            y.begin(/*band=*/0, ch), 0.f);
166     if (y_power == 0.f) {
167       ADD_FAILURE();
168       results[ch] = -1.f;
169     }
170     results[ch] = output_power / y_power;
171   }
172   return results;
173 }
174 
// Formats the parameter combination of one test run, for use as a
// SCOPED_TRACE label.
std::string ProduceDebugText(size_t num_render_channels,
                             size_t num_capture_channels,
                             size_t delay,
                             int filter_length_blocks) {
  std::string text = "delay: " + std::to_string(delay) + ", ";
  text += "filter_length_blocks:" + std::to_string(filter_length_blocks);
  text += ", ";
  text += "num_render_channels:" + std::to_string(num_render_channels);
  text += ", ";
  text += "num_capture_channels:" + std::to_string(num_capture_channels);
  return text;
}
186 
187 }  // namespace
188 
#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)

// Verifies that the check for a null data dumper works: constructing a
// Subtractor without a data dumper is expected to terminate the process.
TEST(SubtractorDeathTest, NullDataDumper) {
  EchoCanceller3Config default_config;
  EXPECT_DEATH(Subtractor(default_config, /*num_render_channels=*/1,
                          /*num_capture_channels=*/1, /*data_dumper=*/nullptr,
                          DetectOptimization()),
               "");
}

#endif  // RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
199 
200 // Verifies that the subtractor is able to converge on correlated data.
TEST(Subtractor,Convergence)201 TEST(Subtractor, Convergence) {
202   std::vector<int> blocks_with_echo_path_changes;
203   for (size_t filter_length_blocks : {12, 20, 30}) {
204     for (size_t delay_samples : {0, 64, 150, 200, 301}) {
205       SCOPED_TRACE(ProduceDebugText(1, 1, delay_samples, filter_length_blocks));
206       std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
207           1, 1, 2500, delay_samples, filter_length_blocks, filter_length_blocks,
208           false, blocks_with_echo_path_changes);
209 
210       for (float echo_to_nearend_power : echo_to_nearend_powers) {
211         EXPECT_GT(0.1f, echo_to_nearend_power);
212       }
213     }
214   }
215 }
216 
217 // Verifies that the subtractor is able to handle the case when the refined
218 // filter is longer than the coarse filter.
TEST(Subtractor,RefinedFilterLongerThanCoarseFilter)219 TEST(Subtractor, RefinedFilterLongerThanCoarseFilter) {
220   std::vector<int> blocks_with_echo_path_changes;
221   std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
222       1, 1, 400, 64, 20, 15, false, blocks_with_echo_path_changes);
223   for (float echo_to_nearend_power : echo_to_nearend_powers) {
224     EXPECT_GT(0.5f, echo_to_nearend_power);
225   }
226 }
227 
228 // Verifies that the subtractor is able to handle the case when the coarse
229 // filter is longer than the refined filter.
TEST(Subtractor,CoarseFilterLongerThanRefinedFilter)230 TEST(Subtractor, CoarseFilterLongerThanRefinedFilter) {
231   std::vector<int> blocks_with_echo_path_changes;
232   std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
233       1, 1, 400, 64, 15, 20, false, blocks_with_echo_path_changes);
234   for (float echo_to_nearend_power : echo_to_nearend_powers) {
235     EXPECT_GT(0.5f, echo_to_nearend_power);
236   }
237 }
238 
239 // Verifies that the subtractor does not converge on uncorrelated signals.
TEST(Subtractor,NonConvergenceOnUncorrelatedSignals)240 TEST(Subtractor, NonConvergenceOnUncorrelatedSignals) {
241   std::vector<int> blocks_with_echo_path_changes;
242   for (size_t filter_length_blocks : {12, 20, 30}) {
243     for (size_t delay_samples : {0, 64, 150, 200, 301}) {
244       SCOPED_TRACE(ProduceDebugText(1, 1, delay_samples, filter_length_blocks));
245 
246       std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
247           1, 1, 3000, delay_samples, filter_length_blocks, filter_length_blocks,
248           true, blocks_with_echo_path_changes);
249       for (float echo_to_nearend_power : echo_to_nearend_powers) {
250         EXPECT_NEAR(1.f, echo_to_nearend_power, 0.1);
251       }
252     }
253   }
254 }
255 
256 class SubtractorMultiChannelUpToEightRender
257     : public ::testing::Test,
258       public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
259 
260 #if defined(NDEBUG)
261 INSTANTIATE_TEST_SUITE_P(NonDebugMultiChannel,
262                          SubtractorMultiChannelUpToEightRender,
263                          ::testing::Combine(::testing::Values(1, 2, 8),
264                                             ::testing::Values(1, 2, 4)));
265 #else
266 INSTANTIATE_TEST_SUITE_P(DebugMultiChannel,
267                          SubtractorMultiChannelUpToEightRender,
268                          ::testing::Combine(::testing::Values(1, 2),
269                                             ::testing::Values(1, 2)));
270 #endif
271 
272 // Verifies that the subtractor is able to converge on correlated data.
TEST_P(SubtractorMultiChannelUpToEightRender,Convergence)273 TEST_P(SubtractorMultiChannelUpToEightRender, Convergence) {
274   const size_t num_render_channels = std::get<0>(GetParam());
275   const size_t num_capture_channels = std::get<1>(GetParam());
276 
277   std::vector<int> blocks_with_echo_path_changes;
278   size_t num_blocks_to_process = 2500 * num_render_channels;
279   std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
280       num_render_channels, num_capture_channels, num_blocks_to_process, 64, 20,
281       20, false, blocks_with_echo_path_changes);
282 
283   for (float echo_to_nearend_power : echo_to_nearend_powers) {
284     EXPECT_GT(0.1f, echo_to_nearend_power);
285   }
286 }
287 
288 class SubtractorMultiChannelUpToFourRender
289     : public ::testing::Test,
290       public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
291 
292 #if defined(NDEBUG)
293 INSTANTIATE_TEST_SUITE_P(NonDebugMultiChannel,
294                          SubtractorMultiChannelUpToFourRender,
295                          ::testing::Combine(::testing::Values(1, 2, 4),
296                                             ::testing::Values(1, 2, 4)));
297 #else
298 INSTANTIATE_TEST_SUITE_P(DebugMultiChannel,
299                          SubtractorMultiChannelUpToFourRender,
300                          ::testing::Combine(::testing::Values(1, 2),
301                                             ::testing::Values(1, 2)));
302 #endif
303 
304 // Verifies that the subtractor does not converge on uncorrelated signals.
TEST_P(SubtractorMultiChannelUpToFourRender,NonConvergenceOnUncorrelatedSignals)305 TEST_P(SubtractorMultiChannelUpToFourRender,
306        NonConvergenceOnUncorrelatedSignals) {
307   const size_t num_render_channels = std::get<0>(GetParam());
308   const size_t num_capture_channels = std::get<1>(GetParam());
309 
310   std::vector<int> blocks_with_echo_path_changes;
311   size_t num_blocks_to_process = 5000 * num_render_channels;
312   std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
313       num_render_channels, num_capture_channels, num_blocks_to_process, 64, 20,
314       20, true, blocks_with_echo_path_changes);
315   for (float echo_to_nearend_power : echo_to_nearend_powers) {
316     EXPECT_LT(.8f, echo_to_nearend_power);
317     EXPECT_NEAR(1.f, echo_to_nearend_power, 0.25f);
318   }
319 }
320 }  // namespace webrtc
321