/*
 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */
#include "modules/audio_processing/aec3/matched_filter.h"

// Defines WEBRTC_ARCH_X86_FAMILY, used below.
#include "rtc_base/system/arch.h"

#if defined(WEBRTC_HAS_NEON)
#include <arm_neon.h>
#endif
#if defined(WEBRTC_ARCH_X86_FAMILY)
#include <emmintrin.h>
#endif
#include <algorithm>
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <numeric>

#include "absl/types/optional.h"
#include "api/array_view.h"
#include "modules/audio_processing/aec3/downsampled_render_buffer.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/experiments/field_trial_parser.h"
#include "rtc_base/logging.h"
#include "system_wrappers/include/field_trial.h"

namespace {

// Subsample rate used for computing the accumulated error.
// The implementation of some core functions depends on this constant being
// equal to 4.
constexpr int kAccumulatedErrorSubSampleRate = 4;

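// Tracks, per (sub-sampled) filter coefficient, a smoothed minimum of the
// normalized instantaneous accumulated error: the stored value drops
// immediately when the new normalized error is lower and otherwise leaks
// upwards with a 0.01 smoothing factor.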
void UpdateAccumulatedError(
    const rtc::ArrayView<const float> instantaneous_accumulated_error,
    const rtc::ArrayView<float> accumulated_error,
    float one_over_error_sum_anchor) {
  for (size_t k = 0; k < instantaneous_accumulated_error.size(); ++k) {
    float error_norm =
        instantaneous_accumulated_error[k] * one_over_error_sum_anchor;
    if (error_norm < accumulated_error[k]) {
      accumulated_error[k] = error_norm;
    } else {
      accumulated_error[k] += 0.01f * (error_norm - accumulated_error[k]);
    }
  }
}

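// Maps the lag of the winning matched filter to a pre-echo lag by scanning the
// accumulated error, which is sub-sampled by kAccumulatedErrorSubSampleRate
// (index k covers filter taps [4 * k, 4 * k + 3]). The returned lag is
// expressed at full tap resolution and includes the alignment shift of the
// winning filter again.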
size_t ComputePreEchoLag(
    const webrtc::MatchedFilter::PreEchoConfiguration& pre_echo_configuration,
    const rtc::ArrayView<const float> accumulated_error,
    size_t lag,
    size_t alignment_shift_winner) {
  RTC_DCHECK_GE(lag, alignment_shift_winner);
  size_t pre_echo_lag_estimate = lag - alignment_shift_winner;
  size_t maximum_pre_echo_lag =
      std::min(pre_echo_lag_estimate / kAccumulatedErrorSubSampleRate,
               accumulated_error.size());
  switch (pre_echo_configuration.mode) {
    case 0:
      // Mode 0: the pre-echo lag is the first coefficient whose error is both
      // below the threshold and below the threshold times the error of the
      // previous coefficient, i.e., the error decreases steeply enough.
      for (size_t k = 1; k < maximum_pre_echo_lag; ++k) {
        if (accumulated_error[k] <
                pre_echo_configuration.threshold * accumulated_error[k - 1] &&
            accumulated_error[k] < pre_echo_configuration.threshold) {
          pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1;
          break;
        }
      }
      break;
    case 1:
      // Mode 1: the pre-echo lag is the first coefficient with an error below
      // the threshold.
      for (size_t k = 0; k < maximum_pre_echo_lag; ++k) {
        if (accumulated_error[k] < pre_echo_configuration.threshold) {
          pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1;
          break;
        }
      }
      break;
    case 2:
      // Mode 2: the pre-echo lag is the earliest coefficient from which the
      // error stays at or below the threshold all the way up to the lag.
      for (int k = static_cast<int>(maximum_pre_echo_lag) - 1; k >= 0; --k) {
        if (accumulated_error[k] > pre_echo_configuration.threshold) {
          break;
        }
        pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1;
      }
      break;
    default:
      RTC_DCHECK_NOTREACHED();
      break;
  }
  return pre_echo_lag_estimate + alignment_shift_winner;
}

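// Reads the optional "WebRTC-Aec3PreEchoConfiguration" field trial and falls
// back to the defaults (threshold = 0.5, mode = 0) for missing or out-of-range
// values. The group string is handed to ParseFieldTrial, so (assuming the
// usual key:value,key:value trial syntax) a string such as
// "threshold:0.25,mode:2" is expected here.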
webrtc::MatchedFilter::PreEchoConfiguration FetchPreEchoConfiguration() {
  float threshold = 0.5f;
  int mode = 0;
  const std::string pre_echo_configuration_field_trial =
      webrtc::field_trial::FindFullName("WebRTC-Aec3PreEchoConfiguration");
  webrtc::FieldTrialParameter<double> threshold_field_trial_parameter(
      /*key=*/"threshold", /*default_value=*/threshold);
  webrtc::FieldTrialParameter<int> mode_field_trial_parameter(
      /*key=*/"mode", /*default_value=*/mode);
  webrtc::ParseFieldTrial(
      {&threshold_field_trial_parameter, &mode_field_trial_parameter},
      pre_echo_configuration_field_trial);
  float threshold_read =
      static_cast<float>(threshold_field_trial_parameter.Get());
  int mode_read = mode_field_trial_parameter.Get();
  if (threshold_read < 1.0f && threshold_read > 0.0f) {
    threshold = threshold_read;
  } else {
    RTC_LOG(LS_ERROR)
        << "AEC3: Pre echo configuration: wrong input, threshold = "
        << threshold_read << ".";
  }
  if (mode_read >= 0 && mode_read <= 2) {
    mode = mode_read;
  } else {
    RTC_LOG(LS_ERROR) << "AEC3: Pre echo configuration: wrong input, mode = "
                      << mode_read << ".";
  }
  RTC_LOG(LS_INFO) << "AEC3: Pre echo configuration: threshold = " << threshold
                   << ", mode = " << mode << ".";
  return {.threshold = threshold, .mode = mode};
}

}  // namespace

namespace webrtc {
namespace aec3 {

#if defined(WEBRTC_HAS_NEON)

inline float SumAllElements(float32x4_t elements) {
  float32x2_t sum = vpadd_f32(vget_low_f32(elements), vget_high_f32(elements));
  sum = vpadd_f32(sum, sum);
  return vget_lane_f32(sum, 0);
}

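// NEON variant of the matched filter core that, in addition to the NLMS
// update, accumulates the squared prediction error for every group of
// kAccumulatedErrorSubSampleRate (4) filter taps. When the circular render
// buffer wraps around within the filter window, the relevant part of x is
// first linearized into scratch_memory so that the inner loops can read it
// contiguously.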
void MatchedFilterCoreWithAccumulatedError_NEON(
    size_t x_start_index,
    float x2_sum_threshold,
    float smoothing,
    rtc::ArrayView<const float> x,
    rtc::ArrayView<const float> y,
    rtc::ArrayView<float> h,
    bool* filters_updated,
    float* error_sum,
    rtc::ArrayView<float> accumulated_error,
    rtc::ArrayView<float> scratch_memory) {
  const int h_size = static_cast<int>(h.size());
  const int x_size = static_cast<int>(x.size());
  RTC_DCHECK_EQ(0, h_size % 4);
  std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
  // Process for all samples in the sub-block.
  for (size_t i = 0; i < y.size(); ++i) {
    // Apply the matched filter as filter * x, and compute x * x.
    RTC_DCHECK_GT(x_size, x_start_index);
    // Compute loop chunk sizes until, and after, the wraparound of the
    // circular buffer for x.
    const int chunk1 =
        std::min(h_size, static_cast<int>(x_size - x_start_index));
    if (chunk1 != h_size) {
      const int chunk2 = h_size - chunk1;
      std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin());
      std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1);
    }
    const float* x_p =
        chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
    const float* h_p = &h[0];
    float* accumulated_error_p = &accumulated_error[0];
    // Initialize values for the accumulation.
    float32x4_t x2_sum_128 = vdupq_n_f32(0);
    float x2_sum = 0.f;
    float s = 0;
    // Perform 128 bit vector operations.
    const int limit_by_4 = h_size >> 2;
    for (int k = limit_by_4; k > 0;
         --k, h_p += 4, x_p += 4, accumulated_error_p++) {
      // Load the data into 128 bit vectors.
      const float32x4_t x_k = vld1q_f32(x_p);
      const float32x4_t h_k = vld1q_f32(h_p);
      // Compute and accumulate x * x.
      x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k);
      // Compute x * h.
      float32x4_t hk_xk_128 = vmulq_f32(h_k, x_k);
      s += SumAllElements(hk_xk_128);
      const float e = s - y[i];
      accumulated_error_p[0] += e * e;
    }
    // Combine the accumulated vector and scalar values.
    x2_sum += SumAllElements(x2_sum_128);
    // Compute the matched filter error.
    float e = y[i] - s;
    const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
    (*error_sum) += e * e;
    // Update the matched filter estimate in an NLMS manner.
    if (x2_sum > x2_sum_threshold && !saturation) {
      RTC_DCHECK_LT(0.f, x2_sum);
      const float alpha = smoothing * e / x2_sum;
      const float32x4_t alpha_128 = vmovq_n_f32(alpha);
      // filter = filter + smoothing * (y - filter * x) * x / (x * x).
      float* h_p = &h[0];
      x_p = chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
      // Perform 128 bit vector operations.
      const int limit_by_4 = h_size >> 2;
      for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
        // Load the data into 128 bit vectors.
        float32x4_t h_k = vld1q_f32(h_p);
        const float32x4_t x_k = vld1q_f32(x_p);
        // Compute h = h + alpha * x.
        h_k = vmlaq_f32(h_k, alpha_128, x_k);
        // Store the result.
        vst1q_f32(h_p, h_k);
      }
      *filters_updated = true;
    }
    x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
  }
}

void MatchedFilterCore_NEON(size_t x_start_index,
                            float x2_sum_threshold,
                            float smoothing,
                            rtc::ArrayView<const float> x,
                            rtc::ArrayView<const float> y,
                            rtc::ArrayView<float> h,
                            bool* filters_updated,
                            float* error_sum,
                            bool compute_accumulated_error,
                            rtc::ArrayView<float> accumulated_error,
                            rtc::ArrayView<float> scratch_memory) {
  const int h_size = static_cast<int>(h.size());
  const int x_size = static_cast<int>(x.size());
  RTC_DCHECK_EQ(0, h_size % 4);

  if (compute_accumulated_error) {
    return MatchedFilterCoreWithAccumulatedError_NEON(
        x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated,
        error_sum, accumulated_error, scratch_memory);
  }

  // Process for all samples in the sub-block.
  for (size_t i = 0; i < y.size(); ++i) {
    // Apply the matched filter as filter * x, and compute x * x.

    RTC_DCHECK_GT(x_size, x_start_index);
    const float* x_p = &x[x_start_index];
    const float* h_p = &h[0];

    // Initialize values for the accumulation.
    float32x4_t s_128 = vdupq_n_f32(0);
    float32x4_t x2_sum_128 = vdupq_n_f32(0);
    float x2_sum = 0.f;
    float s = 0;

    // Compute loop chunk sizes until, and after, the wraparound of the
    // circular buffer for x.
    const int chunk1 =
        std::min(h_size, static_cast<int>(x_size - x_start_index));

    // Perform the loop in two chunks.
    const int chunk2 = h_size - chunk1;
    for (int limit : {chunk1, chunk2}) {
      // Perform 128 bit vector operations.
      const int limit_by_4 = limit >> 2;
      for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
        // Load the data into 128 bit vectors.
        const float32x4_t x_k = vld1q_f32(x_p);
        const float32x4_t h_k = vld1q_f32(h_p);
        // Compute and accumulate x * x and h * x.
        x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k);
        s_128 = vmlaq_f32(s_128, h_k, x_k);
      }

      // Perform non-vector operations for any remaining items.
      for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
        const float x_k = *x_p;
        x2_sum += x_k * x_k;
        s += *h_p * x_k;
      }

      x_p = &x[0];
    }

    // Combine the accumulated vector and scalar values.
    s += SumAllElements(s_128);
    x2_sum += SumAllElements(x2_sum_128);

    // Compute the matched filter error.
    float e = y[i] - s;
    const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
    (*error_sum) += e * e;

    // Update the matched filter estimate in an NLMS manner.
    if (x2_sum > x2_sum_threshold && !saturation) {
      RTC_DCHECK_LT(0.f, x2_sum);
      const float alpha = smoothing * e / x2_sum;
      const float32x4_t alpha_128 = vmovq_n_f32(alpha);

      // filter = filter + smoothing * (y - filter * x) * x / (x * x).
      float* h_p = &h[0];
      x_p = &x[x_start_index];

      // Perform the loop in two chunks.
      for (int limit : {chunk1, chunk2}) {
        // Perform 128 bit vector operations.
        const int limit_by_4 = limit >> 2;
        for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
          // Load the data into 128 bit vectors.
          float32x4_t h_k = vld1q_f32(h_p);
          const float32x4_t x_k = vld1q_f32(x_p);
          // Compute h = h + alpha * x.
          h_k = vmlaq_f32(h_k, alpha_128, x_k);

          // Store the result.
          vst1q_f32(h_p, h_k);
        }

        // Perform non-vector operations for any remaining items.
        for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
          *h_p += alpha * *x_p;
        }

        x_p = &x[0];
      }

      *filters_updated = true;
    }

    x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
  }
}

#endif

#if defined(WEBRTC_ARCH_X86_FAMILY)

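// SSE2 variant of the matched filter core with accumulated-error computation.
// The inner loop is unrolled to process 8 filter taps (two 128-bit vectors)
// per iteration and writes two accumulated-error entries per iteration, which
// is why h_size is required to be a multiple of 8 here. As in the NEON
// variant, x is linearized into scratch_memory when the circular render buffer
// wraps around within the filter window.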
void MatchedFilterCore_AccumulatedError_SSE2(
    size_t x_start_index,
    float x2_sum_threshold,
    float smoothing,
    rtc::ArrayView<const float> x,
    rtc::ArrayView<const float> y,
    rtc::ArrayView<float> h,
    bool* filters_updated,
    float* error_sum,
    rtc::ArrayView<float> accumulated_error,
    rtc::ArrayView<float> scratch_memory) {
  const int h_size = static_cast<int>(h.size());
  const int x_size = static_cast<int>(x.size());
  RTC_DCHECK_EQ(0, h_size % 8);
  std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
  // Process for all samples in the sub-block.
  for (size_t i = 0; i < y.size(); ++i) {
    // Apply the matched filter as filter * x, and compute x * x.
    RTC_DCHECK_GT(x_size, x_start_index);
    const int chunk1 =
        std::min(h_size, static_cast<int>(x_size - x_start_index));
    if (chunk1 != h_size) {
      const int chunk2 = h_size - chunk1;
      std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin());
      std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1);
    }
    const float* x_p =
        chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
    const float* h_p = &h[0];
    float* a_p = &accumulated_error[0];
    __m128 s_inst_128;
    __m128 s_inst_128_4;
    __m128 x2_sum_128 = _mm_set1_ps(0);
    __m128 x2_sum_128_4 = _mm_set1_ps(0);
    __m128 e_128;
    float* const s_p = reinterpret_cast<float*>(&s_inst_128);
    float* const s_4_p = reinterpret_cast<float*>(&s_inst_128_4);
    float* const e_p = reinterpret_cast<float*>(&e_128);
    float x2_sum = 0.0f;
    float s_acum = 0;
    // Perform 128 bit vector operations.
    const int limit_by_8 = h_size >> 3;
    for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8, a_p += 2) {
      // Load the data into 128 bit vectors.
      const __m128 x_k = _mm_loadu_ps(x_p);
      const __m128 h_k = _mm_loadu_ps(h_p);
      const __m128 x_k_4 = _mm_loadu_ps(x_p + 4);
      const __m128 h_k_4 = _mm_loadu_ps(h_p + 4);
      const __m128 xx = _mm_mul_ps(x_k, x_k);
      const __m128 xx_4 = _mm_mul_ps(x_k_4, x_k_4);
      // Compute and accumulate x * x and h * x.
      x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
      x2_sum_128_4 = _mm_add_ps(x2_sum_128_4, xx_4);
      s_inst_128 = _mm_mul_ps(h_k, x_k);
      s_inst_128_4 = _mm_mul_ps(h_k_4, x_k_4);
      s_acum += s_p[0] + s_p[1] + s_p[2] + s_p[3];
      e_p[0] = s_acum - y[i];
      s_acum += s_4_p[0] + s_4_p[1] + s_4_p[2] + s_4_p[3];
      e_p[1] = s_acum - y[i];
      a_p[0] += e_p[0] * e_p[0];
      a_p[1] += e_p[1] * e_p[1];
    }
    // Combine the accumulated vector and scalar values.
    x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4);
    float* v = reinterpret_cast<float*>(&x2_sum_128);
    x2_sum += v[0] + v[1] + v[2] + v[3];
    // Compute the matched filter error.
    float e = y[i] - s_acum;
    const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
    (*error_sum) += e * e;
    // Update the matched filter estimate in an NLMS manner.
    if (x2_sum > x2_sum_threshold && !saturation) {
      RTC_DCHECK_LT(0.f, x2_sum);
      const float alpha = smoothing * e / x2_sum;
      const __m128 alpha_128 = _mm_set1_ps(alpha);
      // filter = filter + smoothing * (y - filter * x) * x / (x * x).
      float* h_p = &h[0];
      const float* x_p =
          chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
      // Perform 128 bit vector operations.
      const int limit_by_4 = h_size >> 2;
      for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
        // Load the data into 128 bit vectors.
        __m128 h_k = _mm_loadu_ps(h_p);
        const __m128 x_k = _mm_loadu_ps(x_p);
        // Compute h = h + alpha * x.
        const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
        h_k = _mm_add_ps(h_k, alpha_x);
        // Store the result.
        _mm_storeu_ps(h_p, h_k);
      }
      *filters_updated = true;
    }
    x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
  }
}

void MatchedFilterCore_SSE2(size_t x_start_index,
                            float x2_sum_threshold,
                            float smoothing,
                            rtc::ArrayView<const float> x,
                            rtc::ArrayView<const float> y,
                            rtc::ArrayView<float> h,
                            bool* filters_updated,
                            float* error_sum,
                            bool compute_accumulated_error,
                            rtc::ArrayView<float> accumulated_error,
                            rtc::ArrayView<float> scratch_memory) {
  if (compute_accumulated_error) {
    return MatchedFilterCore_AccumulatedError_SSE2(
        x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated,
        error_sum, accumulated_error, scratch_memory);
  }
  const int h_size = static_cast<int>(h.size());
  const int x_size = static_cast<int>(x.size());
  RTC_DCHECK_EQ(0, h_size % 4);
  // Process for all samples in the sub-block.
  for (size_t i = 0; i < y.size(); ++i) {
    // Apply the matched filter as filter * x, and compute x * x.
    RTC_DCHECK_GT(x_size, x_start_index);
    const float* x_p = &x[x_start_index];
    const float* h_p = &h[0];
    // Initialize values for the accumulation.
    __m128 s_128 = _mm_set1_ps(0);
    __m128 s_128_4 = _mm_set1_ps(0);
    __m128 x2_sum_128 = _mm_set1_ps(0);
    __m128 x2_sum_128_4 = _mm_set1_ps(0);
    float x2_sum = 0.f;
    float s = 0;
    // Compute loop chunk sizes until, and after, the wraparound of the
    // circular buffer for x.
    const int chunk1 =
        std::min(h_size, static_cast<int>(x_size - x_start_index));
    // Perform the loop in two chunks.
    const int chunk2 = h_size - chunk1;
    for (int limit : {chunk1, chunk2}) {
      // Perform 128 bit vector operations.
      const int limit_by_8 = limit >> 3;
      for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) {
        // Load the data into 128 bit vectors.
        const __m128 x_k = _mm_loadu_ps(x_p);
        const __m128 h_k = _mm_loadu_ps(h_p);
        const __m128 x_k_4 = _mm_loadu_ps(x_p + 4);
        const __m128 h_k_4 = _mm_loadu_ps(h_p + 4);
        const __m128 xx = _mm_mul_ps(x_k, x_k);
        const __m128 xx_4 = _mm_mul_ps(x_k_4, x_k_4);
        // Compute and accumulate x * x and h * x.
        x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
        x2_sum_128_4 = _mm_add_ps(x2_sum_128_4, xx_4);
        const __m128 hx = _mm_mul_ps(h_k, x_k);
        const __m128 hx_4 = _mm_mul_ps(h_k_4, x_k_4);
        s_128 = _mm_add_ps(s_128, hx);
        s_128_4 = _mm_add_ps(s_128_4, hx_4);
      }
      // Perform non-vector operations for any remaining items.
      for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) {
        const float x_k = *x_p;
        x2_sum += x_k * x_k;
        s += *h_p * x_k;
      }
      x_p = &x[0];
    }
    // Combine the accumulated vector and scalar values.
    x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4);
    float* v = reinterpret_cast<float*>(&x2_sum_128);
    x2_sum += v[0] + v[1] + v[2] + v[3];
    s_128 = _mm_add_ps(s_128, s_128_4);
    v = reinterpret_cast<float*>(&s_128);
    s += v[0] + v[1] + v[2] + v[3];
    // Compute the matched filter error.
    float e = y[i] - s;
    const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
    (*error_sum) += e * e;
    // Update the matched filter estimate in an NLMS manner.
    if (x2_sum > x2_sum_threshold && !saturation) {
      RTC_DCHECK_LT(0.f, x2_sum);
      const float alpha = smoothing * e / x2_sum;
      const __m128 alpha_128 = _mm_set1_ps(alpha);
      // filter = filter + smoothing * (y - filter * x) * x / (x * x).
      float* h_p = &h[0];
      x_p = &x[x_start_index];
      // Perform the loop in two chunks.
      for (int limit : {chunk1, chunk2}) {
        // Perform 128 bit vector operations.
        const int limit_by_4 = limit >> 2;
        for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
          // Load the data into 128 bit vectors.
          __m128 h_k = _mm_loadu_ps(h_p);
          const __m128 x_k = _mm_loadu_ps(x_p);

          // Compute h = h + alpha * x.
          const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
          h_k = _mm_add_ps(h_k, alpha_x);
          // Store the result.
          _mm_storeu_ps(h_p, h_k);
        }
        // Perform non-vector operations for any remaining items.
        for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
          *h_p += alpha * *x_p;
        }
        x_p = &x[0];
      }
      *filters_updated = true;
    }
    x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
  }
}
#endif

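// Portable fallback implementation of the matched filter core. It implements
// the same NLMS recursion as the optimized SSE2/AVX2/NEON variants,
//
//   h[k] <- h[k] + smoothing * e * x[k] / x2_sum,  with e = y[i] - h * x,
//
// where x2_sum is the energy of the render samples covered by the filter, and,
// when requested, it accumulates the squared prediction error once per
// kAccumulatedErrorSubSampleRate taps.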
void MatchedFilterCore(size_t x_start_index,
                       float x2_sum_threshold,
                       float smoothing,
                       rtc::ArrayView<const float> x,
                       rtc::ArrayView<const float> y,
                       rtc::ArrayView<float> h,
                       bool* filters_updated,
                       float* error_sum,
                       bool compute_accumulated_error,
                       rtc::ArrayView<float> accumulated_error) {
  if (compute_accumulated_error) {
    std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
  }

  // Process for all samples in the sub-block.
  for (size_t i = 0; i < y.size(); ++i) {
    // Apply the matched filter as filter * x, and compute x * x.
    float x2_sum = 0.f;
    float s = 0;
    size_t x_index = x_start_index;
    if (compute_accumulated_error) {
      for (size_t k = 0; k < h.size(); ++k) {
        x2_sum += x[x_index] * x[x_index];
        s += h[k] * x[x_index];
        x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
        if (((k + 1) & 0b11) == 0) {
          int idx = k >> 2;
          accumulated_error[idx] += (y[i] - s) * (y[i] - s);
        }
      }
    } else {
      for (size_t k = 0; k < h.size(); ++k) {
        x2_sum += x[x_index] * x[x_index];
        s += h[k] * x[x_index];
        x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
      }
    }

    // Compute the matched filter error.
    float e = y[i] - s;
    const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
    (*error_sum) += e * e;

    // Update the matched filter estimate in an NLMS manner.
    if (x2_sum > x2_sum_threshold && !saturation) {
      RTC_DCHECK_LT(0.f, x2_sum);
      const float alpha = smoothing * e / x2_sum;

      // filter = filter + smoothing * (y - filter * x) * x / (x * x).
      size_t x_index = x_start_index;
      for (size_t k = 0; k < h.size(); ++k) {
        h[k] += alpha * x[x_index];
        x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
      }
      *filters_updated = true;
    }

    x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1;
  }
}

size_t MaxSquarePeakIndex(rtc::ArrayView<const float> h) {
  if (h.size() < 2) {
    return 0;
  }
  float max_element1 = h[0] * h[0];
  float max_element2 = h[1] * h[1];
  size_t lag_estimate1 = 0;
  size_t lag_estimate2 = 1;
  const size_t last_index = h.size() - 1;
  // Keeping track of even & odd max elements separately typically allows the
  // compiler to produce more efficient code.
  for (size_t k = 2; k < last_index; k += 2) {
    float element1 = h[k] * h[k];
    float element2 = h[k + 1] * h[k + 1];
    if (element1 > max_element1) {
      max_element1 = element1;
      lag_estimate1 = k;
    }
    if (element2 > max_element2) {
      max_element2 = element2;
      lag_estimate2 = k + 1;
    }
  }
  if (max_element2 > max_element1) {
    max_element1 = max_element2;
    lag_estimate1 = lag_estimate2;
  }
  // In case of odd h size, we have not yet checked the last element.
  float last_element = h[last_index] * h[last_index];
  if (last_element > max_element1) {
    return last_index;
  }
  return lag_estimate1;
}

}  // namespace aec3

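// The filter bank consists of num_matched_filters NLMS filters, each spanning
// window_size_sub_blocks * sub_block_size taps of the downsampled render
// signal and offset from its predecessor by alignment_shift_sub_blocks *
// sub_block_size samples. When pre-echo detection is enabled, a sub-sampled
// accumulated-error vector is kept per filter, together with scratch memory
// used by the optimized cores to linearize the circular render buffer.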
MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
                             Aec3Optimization optimization,
                             size_t sub_block_size,
                             size_t window_size_sub_blocks,
                             int num_matched_filters,
                             size_t alignment_shift_sub_blocks,
                             float excitation_limit,
                             float smoothing_fast,
                             float smoothing_slow,
                             float matching_filter_threshold,
                             bool detect_pre_echo)
    : data_dumper_(data_dumper),
      optimization_(optimization),
      sub_block_size_(sub_block_size),
      filter_intra_lag_shift_(alignment_shift_sub_blocks * sub_block_size_),
      filters_(
          num_matched_filters,
          std::vector<float>(window_size_sub_blocks * sub_block_size_, 0.f)),
      filters_offsets_(num_matched_filters, 0),
      excitation_limit_(excitation_limit),
      smoothing_fast_(smoothing_fast),
      smoothing_slow_(smoothing_slow),
      matching_filter_threshold_(matching_filter_threshold),
      detect_pre_echo_(detect_pre_echo),
      pre_echo_config_(FetchPreEchoConfiguration()) {
  RTC_DCHECK(data_dumper);
  RTC_DCHECK_LT(0, window_size_sub_blocks);
  RTC_DCHECK((kBlockSize % sub_block_size) == 0);
  RTC_DCHECK((sub_block_size % 4) == 0);
  static_assert(kAccumulatedErrorSubSampleRate == 4);
  if (detect_pre_echo_) {
    accumulated_error_ = std::vector<std::vector<float>>(
        num_matched_filters,
        std::vector<float>(window_size_sub_blocks * sub_block_size_ /
                               kAccumulatedErrorSubSampleRate,
                           1.0f));

    instantaneous_accumulated_error_ =
        std::vector<float>(window_size_sub_blocks * sub_block_size_ /
                               kAccumulatedErrorSubSampleRate,
                           0.0f);
    scratch_memory_ =
        std::vector<float>(window_size_sub_blocks * sub_block_size_);
  }
}

MatchedFilter::~MatchedFilter() = default;

void MatchedFilter::Reset() {
  for (auto& f : filters_) {
    std::fill(f.begin(), f.end(), 0.f);
  }

  for (auto& e : accumulated_error_) {
    std::fill(e.begin(), e.end(), 1.0f);
  }

  winner_lag_ = absl::nullopt;
  reported_lag_estimate_ = absl::nullopt;
}

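// Runs one update of the whole filter bank on a capture sub-block: the capture
// energy is used as an error anchor, each filter is adapted against the render
// buffer at its own alignment shift, and the filter with the lowest error that
// passes the reliability checks becomes the winner. For the winning filter,
// the accumulated error is updated and used to estimate the pre-echo lag when
// pre-echo detection is enabled.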
void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
                           rtc::ArrayView<const float> capture,
                           bool use_slow_smoothing) {
  RTC_DCHECK_EQ(sub_block_size_, capture.size());
  auto& y = capture;

  const float smoothing =
      use_slow_smoothing ? smoothing_slow_ : smoothing_fast_;

  const float x2_sum_threshold =
      filters_[0].size() * excitation_limit_ * excitation_limit_;

  // Compute anchor for the matched filter error.
  float error_sum_anchor = 0.0f;
  for (size_t k = 0; k < y.size(); ++k) {
    error_sum_anchor += y[k] * y[k];
  }

  // Apply all matched filters.
  float winner_error_sum = error_sum_anchor;
  winner_lag_ = absl::nullopt;
  reported_lag_estimate_ = absl::nullopt;
  size_t alignment_shift = 0;
  absl::optional<size_t> previous_lag_estimate;
  const int num_filters = static_cast<int>(filters_.size());
  int winner_index = -1;
  for (int n = 0; n < num_filters; ++n) {
    float error_sum = 0.f;
    bool filters_updated = false;
    const bool compute_pre_echo =
        detect_pre_echo_ && n == last_detected_best_lag_filter_;

    size_t x_start_index =
        (render_buffer.read + alignment_shift + sub_block_size_ - 1) %
        render_buffer.buffer.size();

    switch (optimization_) {
#if defined(WEBRTC_ARCH_X86_FAMILY)
      case Aec3Optimization::kSse2:
        aec3::MatchedFilterCore_SSE2(
            x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer,
            y, filters_[n], &filters_updated, &error_sum, compute_pre_echo,
            instantaneous_accumulated_error_, scratch_memory_);
        break;
      case Aec3Optimization::kAvx2:
        aec3::MatchedFilterCore_AVX2(
            x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer,
            y, filters_[n], &filters_updated, &error_sum, compute_pre_echo,
            instantaneous_accumulated_error_, scratch_memory_);
        break;
#endif
#if defined(WEBRTC_HAS_NEON)
      case Aec3Optimization::kNeon:
        aec3::MatchedFilterCore_NEON(
            x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer,
            y, filters_[n], &filters_updated, &error_sum, compute_pre_echo,
            instantaneous_accumulated_error_, scratch_memory_);
        break;
#endif
      default:
        aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, smoothing,
                                render_buffer.buffer, y, filters_[n],
                                &filters_updated, &error_sum, compute_pre_echo,
                                instantaneous_accumulated_error_);
    }

    // Estimate the lag in the matched filter as the distance to the portion in
    // the filter that contributes the most to the matched filter output. This
    // is detected as the peak of the matched filter.
    const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]);
    const bool reliable =
        lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) &&
        error_sum < matching_filter_threshold_ * error_sum_anchor;

    // Find the best estimate.
    const size_t lag = lag_estimate + alignment_shift;
    if (filters_updated && reliable && error_sum < winner_error_sum) {
      winner_error_sum = error_sum;
      winner_index = n;
      // If two matched filters produce the same winning lag (overlap region),
      // the one with the smaller index is chosen so that pre-echoes can be
      // searched for.
      if (previous_lag_estimate && previous_lag_estimate == lag) {
        winner_lag_ = previous_lag_estimate;
        winner_index = n - 1;
      } else {
        winner_lag_ = lag;
      }
    }
    previous_lag_estimate = lag;
    alignment_shift += filter_intra_lag_shift_;
  }

  if (winner_index != -1) {
    RTC_DCHECK(winner_lag_.has_value());
    reported_lag_estimate_ =
        LagEstimate(winner_lag_.value(), /*pre_echo_lag=*/winner_lag_.value());
    if (detect_pre_echo_ && last_detected_best_lag_filter_ == winner_index) {
      if (error_sum_anchor > 30.0f * 30.0f * y.size()) {
        UpdateAccumulatedError(instantaneous_accumulated_error_,
                               accumulated_error_[winner_index],
                               1.0f / error_sum_anchor);
      }
      reported_lag_estimate_->pre_echo_lag = ComputePreEchoLag(
          pre_echo_config_, accumulated_error_[winner_index],
          winner_lag_.value(),
          winner_index * filter_intra_lag_shift_ /*alignment_shift_winner*/);
    }
    last_detected_best_lag_filter_ = winner_index;
  }
  if (ApmDataDumper::IsAvailable()) {
    Dump();
  }
}

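// Logs the time span covered by each matched filter. The start and end
// positions are mapped back via the downsampling factor and converted to
// milliseconds using kFsBy1000 = 16, i.e., assuming the 16 kHz domain in which
// the delay estimation operates.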
void MatchedFilter::LogFilterProperties(int sample_rate_hz,
                                        size_t shift,
                                        size_t downsampling_factor) const {
  size_t alignment_shift = 0;
  constexpr int kFsBy1000 = 16;
  for (size_t k = 0; k < filters_.size(); ++k) {
    int start = static_cast<int>(alignment_shift * downsampling_factor);
    int end = static_cast<int>((alignment_shift + filters_[k].size()) *
                               downsampling_factor);
    RTC_LOG(LS_VERBOSE) << "Filter " << k << ": start: "
                        << (start - static_cast<int>(shift)) / kFsBy1000
                        << " ms, end: "
                        << (end - static_cast<int>(shift)) / kFsBy1000
                        << " ms.";
    alignment_shift += filter_intra_lag_shift_;
  }
}

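// Dumps, per matched filter, the filter coefficients and the lag estimate
// derived from the squared peak, and, when pre-echo detection is enabled, the
// accumulated error and the corresponding pre-echo lag.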
void MatchedFilter::Dump() {
  for (size_t n = 0; n < filters_.size(); ++n) {
    const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]);
    std::string dumper_filter = "aec3_correlator_" + std::to_string(n) + "_h";
    data_dumper_->DumpRaw(dumper_filter.c_str(), filters_[n]);
    std::string dumper_lag = "aec3_correlator_lag_" + std::to_string(n);
    data_dumper_->DumpRaw(dumper_lag.c_str(),
                          lag_estimate + n * filter_intra_lag_shift_);
    if (detect_pre_echo_) {
      std::string dumper_error =
          "aec3_correlator_error_" + std::to_string(n) + "_h";
      data_dumper_->DumpRaw(dumper_error.c_str(), accumulated_error_[n]);

      size_t pre_echo_lag =
          ComputePreEchoLag(pre_echo_config_, accumulated_error_[n],
                            lag_estimate + n * filter_intra_lag_shift_,
                            n * filter_intra_lag_shift_);
      std::string dumper_pre_lag =
          "aec3_correlator_pre_echo_lag_" + std::to_string(n);
      data_dumper_->DumpRaw(dumper_pre_lag.c_str(), pre_echo_lag);
    }
  }
}

}  // namespace webrtc