xref: /aosp_15_r20/external/webrtc/modules/audio_processing/aec3/matched_filter.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 #include "modules/audio_processing/aec3/matched_filter.h"
11 
12 // Defines WEBRTC_ARCH_X86_FAMILY, used below.
13 #include "rtc_base/system/arch.h"
14 
15 #if defined(WEBRTC_HAS_NEON)
16 #include <arm_neon.h>
17 #endif
18 #if defined(WEBRTC_ARCH_X86_FAMILY)
19 #include <emmintrin.h>
20 #endif
21 #include <algorithm>
22 #include <cstddef>
23 #include <initializer_list>
24 #include <iterator>
25 #include <numeric>
26 
27 #include "absl/types/optional.h"
28 #include "api/array_view.h"
29 #include "modules/audio_processing/aec3/downsampled_render_buffer.h"
30 #include "modules/audio_processing/logging/apm_data_dumper.h"
31 #include "rtc_base/checks.h"
32 #include "rtc_base/experiments/field_trial_parser.h"
33 #include "rtc_base/logging.h"
34 #include "system_wrappers/include/field_trial.h"
35 
36 namespace {
37 
38 // Subsample rate used for computing the accumulated error.
39 // The implementation of some core functions depends on this constant being
40 // equal to 4.
41 constexpr int kAccumulatedErrorSubSampleRate = 4;
42 
UpdateAccumulatedError(const rtc::ArrayView<const float> instantaneous_accumulated_error,const rtc::ArrayView<float> accumulated_error,float one_over_error_sum_anchor)43 void UpdateAccumulatedError(
44     const rtc::ArrayView<const float> instantaneous_accumulated_error,
45     const rtc::ArrayView<float> accumulated_error,
46     float one_over_error_sum_anchor) {
47   for (size_t k = 0; k < instantaneous_accumulated_error.size(); ++k) {
48     float error_norm =
49         instantaneous_accumulated_error[k] * one_over_error_sum_anchor;
50     if (error_norm < accumulated_error[k]) {
51       accumulated_error[k] = error_norm;
52     } else {
53       accumulated_error[k] += 0.01f * (error_norm - accumulated_error[k]);
54     }
55   }
56 }
57 
ComputePreEchoLag(const webrtc::MatchedFilter::PreEchoConfiguration & pre_echo_configuration,const rtc::ArrayView<const float> accumulated_error,size_t lag,size_t alignment_shift_winner)58 size_t ComputePreEchoLag(
59     const webrtc::MatchedFilter::PreEchoConfiguration& pre_echo_configuration,
60     const rtc::ArrayView<const float> accumulated_error,
61     size_t lag,
62     size_t alignment_shift_winner) {
63   RTC_DCHECK_GE(lag, alignment_shift_winner);
64   size_t pre_echo_lag_estimate = lag - alignment_shift_winner;
65   size_t maximum_pre_echo_lag =
66       std::min(pre_echo_lag_estimate / kAccumulatedErrorSubSampleRate,
67                accumulated_error.size());
68   switch (pre_echo_configuration.mode) {
69     case 0:
70       // Mode 0: Pre echo lag is defined as the first coefficient with an error
71       // lower than a threshold with a certain decrease slope.
72       for (size_t k = 1; k < maximum_pre_echo_lag; ++k) {
73         if (accumulated_error[k] <
74                 pre_echo_configuration.threshold * accumulated_error[k - 1] &&
75             accumulated_error[k] < pre_echo_configuration.threshold) {
76           pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1;
77           break;
78         }
79       }
80       break;
81     case 1:
82       // Mode 1: Pre echo lag is defined as the first coefficient with an error
83       // lower than a certain threshold.
84       for (size_t k = 0; k < maximum_pre_echo_lag; ++k) {
85         if (accumulated_error[k] < pre_echo_configuration.threshold) {
86           pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1;
87           break;
88         }
89       }
90       break;
91     case 2:
92       // Mode 2: Pre echo lag is defined as the closest coefficient to the lag
93       // with an error lower than a certain threshold.
94       for (int k = static_cast<int>(maximum_pre_echo_lag) - 1; k >= 0; --k) {
95         if (accumulated_error[k] > pre_echo_configuration.threshold) {
96           break;
97         }
98         pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1;
99       }
100       break;
101     default:
102       RTC_DCHECK_NOTREACHED();
103       break;
104   }
105   return pre_echo_lag_estimate + alignment_shift_winner;
106 }
107 
FetchPreEchoConfiguration()108 webrtc::MatchedFilter::PreEchoConfiguration FetchPreEchoConfiguration() {
109   float threshold = 0.5f;
110   int mode = 0;
111   const std::string pre_echo_configuration_field_trial =
112       webrtc::field_trial::FindFullName("WebRTC-Aec3PreEchoConfiguration");
113   webrtc::FieldTrialParameter<double> threshold_field_trial_parameter(
114       /*key=*/"threshold", /*default_value=*/threshold);
115   webrtc::FieldTrialParameter<int> mode_field_trial_parameter(
116       /*key=*/"mode", /*default_value=*/mode);
117   webrtc::ParseFieldTrial(
118       {&threshold_field_trial_parameter, &mode_field_trial_parameter},
119       pre_echo_configuration_field_trial);
120   float threshold_read =
121       static_cast<float>(threshold_field_trial_parameter.Get());
122   int mode_read = mode_field_trial_parameter.Get();
123   if (threshold_read < 1.0f && threshold_read > 0.0f) {
124     threshold = threshold_read;
125   } else {
126     RTC_LOG(LS_ERROR)
127         << "AEC3: Pre echo configuration:  wrong input, threshold = "
128         << threshold_read << ".";
129   }
130   if (mode_read >= 0 && mode_read <= 3) {
131     mode = mode_read;
132   } else {
133     RTC_LOG(LS_ERROR) << "AEC3: Pre echo configuration:  wrong input, mode = "
134                       << mode_read << ".";
135   }
136   RTC_LOG(LS_INFO) << "AEC3: Pre echo configuration:  threshold = " << threshold
137                    << ", mode =  " << mode << ".";
138   return {.threshold = threshold, .mode = mode};
139 }
140 
141 }  // namespace
142 
143 namespace webrtc {
144 namespace aec3 {
145 
146 #if defined(WEBRTC_HAS_NEON)
147 
SumAllElements(float32x4_t elements)148 inline float SumAllElements(float32x4_t elements) {
149   float32x2_t sum = vpadd_f32(vget_low_f32(elements), vget_high_f32(elements));
150   sum = vpadd_f32(sum, sum);
151   return vget_lane_f32(sum, 0);
152 }
153 
MatchedFilterCoreWithAccumulatedError_NEON(size_t x_start_index,float x2_sum_threshold,float smoothing,rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float> h,bool * filters_updated,float * error_sum,rtc::ArrayView<float> accumulated_error,rtc::ArrayView<float> scratch_memory)154 void MatchedFilterCoreWithAccumulatedError_NEON(
155     size_t x_start_index,
156     float x2_sum_threshold,
157     float smoothing,
158     rtc::ArrayView<const float> x,
159     rtc::ArrayView<const float> y,
160     rtc::ArrayView<float> h,
161     bool* filters_updated,
162     float* error_sum,
163     rtc::ArrayView<float> accumulated_error,
164     rtc::ArrayView<float> scratch_memory) {
165   const int h_size = static_cast<int>(h.size());
166   const int x_size = static_cast<int>(x.size());
167   RTC_DCHECK_EQ(0, h_size % 4);
168   std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
169   // Process for all samples in the sub-block.
170   for (size_t i = 0; i < y.size(); ++i) {
171     // Apply the matched filter as filter * x, and compute x * x.
172     RTC_DCHECK_GT(x_size, x_start_index);
173     // Compute loop chunk sizes until, and after, the wraparound of the circular
174     // buffer for x.
175     const int chunk1 =
176         std::min(h_size, static_cast<int>(x_size - x_start_index));
177     if (chunk1 != h_size) {
178       const int chunk2 = h_size - chunk1;
179       std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin());
180       std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1);
181     }
182     const float* x_p =
183         chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
184     const float* h_p = &h[0];
185     float* accumulated_error_p = &accumulated_error[0];
186     // Initialize values for the accumulation.
187     float32x4_t x2_sum_128 = vdupq_n_f32(0);
188     float x2_sum = 0.f;
189     float s = 0;
190     // Perform 128 bit vector operations.
191     const int limit_by_4 = h_size >> 2;
192     for (int k = limit_by_4; k > 0;
193          --k, h_p += 4, x_p += 4, accumulated_error_p++) {
194       // Load the data into 128 bit vectors.
195       const float32x4_t x_k = vld1q_f32(x_p);
196       const float32x4_t h_k = vld1q_f32(h_p);
197       // Compute and accumulate x * x.
198       x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k);
199       // Compute x * h
200       float32x4_t hk_xk_128 = vmulq_f32(h_k, x_k);
201       s += SumAllElements(hk_xk_128);
202       const float e = s - y[i];
203       accumulated_error_p[0] += e * e;
204     }
205     // Combine the accumulated vector and scalar values.
206     x2_sum += SumAllElements(x2_sum_128);
207     // Compute the matched filter error.
208     float e = y[i] - s;
209     const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
210     (*error_sum) += e * e;
211     // Update the matched filter estimate in an NLMS manner.
212     if (x2_sum > x2_sum_threshold && !saturation) {
213       RTC_DCHECK_LT(0.f, x2_sum);
214       const float alpha = smoothing * e / x2_sum;
215       const float32x4_t alpha_128 = vmovq_n_f32(alpha);
216       // filter = filter + smoothing * (y - filter * x) * x / x * x.
217       float* h_p = &h[0];
218       x_p = chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
219       // Perform 128 bit vector operations.
220       const int limit_by_4 = h_size >> 2;
221       for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
222         // Load the data into 128 bit vectors.
223         float32x4_t h_k = vld1q_f32(h_p);
224         const float32x4_t x_k = vld1q_f32(x_p);
225         // Compute h = h + alpha * x.
226         h_k = vmlaq_f32(h_k, alpha_128, x_k);
227         // Store the result.
228         vst1q_f32(h_p, h_k);
229       }
230       *filters_updated = true;
231     }
232     x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
233   }
234 }
235 
MatchedFilterCore_NEON(size_t x_start_index,float x2_sum_threshold,float smoothing,rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float> h,bool * filters_updated,float * error_sum,bool compute_accumulated_error,rtc::ArrayView<float> accumulated_error,rtc::ArrayView<float> scratch_memory)236 void MatchedFilterCore_NEON(size_t x_start_index,
237                             float x2_sum_threshold,
238                             float smoothing,
239                             rtc::ArrayView<const float> x,
240                             rtc::ArrayView<const float> y,
241                             rtc::ArrayView<float> h,
242                             bool* filters_updated,
243                             float* error_sum,
244                             bool compute_accumulated_error,
245                             rtc::ArrayView<float> accumulated_error,
246                             rtc::ArrayView<float> scratch_memory) {
247   const int h_size = static_cast<int>(h.size());
248   const int x_size = static_cast<int>(x.size());
249   RTC_DCHECK_EQ(0, h_size % 4);
250 
251   if (compute_accumulated_error) {
252     return MatchedFilterCoreWithAccumulatedError_NEON(
253         x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated,
254         error_sum, accumulated_error, scratch_memory);
255   }
256 
257   // Process for all samples in the sub-block.
258   for (size_t i = 0; i < y.size(); ++i) {
259     // Apply the matched filter as filter * x, and compute x * x.
260 
261     RTC_DCHECK_GT(x_size, x_start_index);
262     const float* x_p = &x[x_start_index];
263     const float* h_p = &h[0];
264 
265     // Initialize values for the accumulation.
266     float32x4_t s_128 = vdupq_n_f32(0);
267     float32x4_t x2_sum_128 = vdupq_n_f32(0);
268     float x2_sum = 0.f;
269     float s = 0;
270 
271     // Compute loop chunk sizes until, and after, the wraparound of the circular
272     // buffer for x.
273     const int chunk1 =
274         std::min(h_size, static_cast<int>(x_size - x_start_index));
275 
276     // Perform the loop in two chunks.
277     const int chunk2 = h_size - chunk1;
278     for (int limit : {chunk1, chunk2}) {
279       // Perform 128 bit vector operations.
280       const int limit_by_4 = limit >> 2;
281       for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
282         // Load the data into 128 bit vectors.
283         const float32x4_t x_k = vld1q_f32(x_p);
284         const float32x4_t h_k = vld1q_f32(h_p);
285         // Compute and accumulate x * x and h * x.
286         x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k);
287         s_128 = vmlaq_f32(s_128, h_k, x_k);
288       }
289 
290       // Perform non-vector operations for any remaining items.
291       for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
292         const float x_k = *x_p;
293         x2_sum += x_k * x_k;
294         s += *h_p * x_k;
295       }
296 
297       x_p = &x[0];
298     }
299 
300     // Combine the accumulated vector and scalar values.
301     s += SumAllElements(s_128);
302     x2_sum += SumAllElements(x2_sum_128);
303 
304     // Compute the matched filter error.
305     float e = y[i] - s;
306     const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
307     (*error_sum) += e * e;
308 
309     // Update the matched filter estimate in an NLMS manner.
310     if (x2_sum > x2_sum_threshold && !saturation) {
311       RTC_DCHECK_LT(0.f, x2_sum);
312       const float alpha = smoothing * e / x2_sum;
313       const float32x4_t alpha_128 = vmovq_n_f32(alpha);
314 
315       // filter = filter + smoothing * (y - filter * x) * x / x * x.
316       float* h_p = &h[0];
317       x_p = &x[x_start_index];
318 
319       // Perform the loop in two chunks.
320       for (int limit : {chunk1, chunk2}) {
321         // Perform 128 bit vector operations.
322         const int limit_by_4 = limit >> 2;
323         for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
324           // Load the data into 128 bit vectors.
325           float32x4_t h_k = vld1q_f32(h_p);
326           const float32x4_t x_k = vld1q_f32(x_p);
327           // Compute h = h + alpha * x.
328           h_k = vmlaq_f32(h_k, alpha_128, x_k);
329 
330           // Store the result.
331           vst1q_f32(h_p, h_k);
332         }
333 
334         // Perform non-vector operations for any remaining items.
335         for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
336           *h_p += alpha * *x_p;
337         }
338 
339         x_p = &x[0];
340       }
341 
342       *filters_updated = true;
343     }
344 
345     x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
346   }
347 }
348 
349 #endif
350 
351 #if defined(WEBRTC_ARCH_X86_FAMILY)
352 
MatchedFilterCore_AccumulatedError_SSE2(size_t x_start_index,float x2_sum_threshold,float smoothing,rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float> h,bool * filters_updated,float * error_sum,rtc::ArrayView<float> accumulated_error,rtc::ArrayView<float> scratch_memory)353 void MatchedFilterCore_AccumulatedError_SSE2(
354     size_t x_start_index,
355     float x2_sum_threshold,
356     float smoothing,
357     rtc::ArrayView<const float> x,
358     rtc::ArrayView<const float> y,
359     rtc::ArrayView<float> h,
360     bool* filters_updated,
361     float* error_sum,
362     rtc::ArrayView<float> accumulated_error,
363     rtc::ArrayView<float> scratch_memory) {
364   const int h_size = static_cast<int>(h.size());
365   const int x_size = static_cast<int>(x.size());
366   RTC_DCHECK_EQ(0, h_size % 8);
367   std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
368   // Process for all samples in the sub-block.
369   for (size_t i = 0; i < y.size(); ++i) {
370     // Apply the matched filter as filter * x, and compute x * x.
371     RTC_DCHECK_GT(x_size, x_start_index);
372     const int chunk1 =
373         std::min(h_size, static_cast<int>(x_size - x_start_index));
374     if (chunk1 != h_size) {
375       const int chunk2 = h_size - chunk1;
376       std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin());
377       std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1);
378     }
379     const float* x_p =
380         chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
381     const float* h_p = &h[0];
382     float* a_p = &accumulated_error[0];
383     __m128 s_inst_128;
384     __m128 s_inst_128_4;
385     __m128 x2_sum_128 = _mm_set1_ps(0);
386     __m128 x2_sum_128_4 = _mm_set1_ps(0);
387     __m128 e_128;
388     float* const s_p = reinterpret_cast<float*>(&s_inst_128);
389     float* const s_4_p = reinterpret_cast<float*>(&s_inst_128_4);
390     float* const e_p = reinterpret_cast<float*>(&e_128);
391     float x2_sum = 0.0f;
392     float s_acum = 0;
393     // Perform 128 bit vector operations.
394     const int limit_by_8 = h_size >> 3;
395     for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8, a_p += 2) {
396       // Load the data into 128 bit vectors.
397       const __m128 x_k = _mm_loadu_ps(x_p);
398       const __m128 h_k = _mm_loadu_ps(h_p);
399       const __m128 x_k_4 = _mm_loadu_ps(x_p + 4);
400       const __m128 h_k_4 = _mm_loadu_ps(h_p + 4);
401       const __m128 xx = _mm_mul_ps(x_k, x_k);
402       const __m128 xx_4 = _mm_mul_ps(x_k_4, x_k_4);
403       // Compute and accumulate x * x and h * x.
404       x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
405       x2_sum_128_4 = _mm_add_ps(x2_sum_128_4, xx_4);
406       s_inst_128 = _mm_mul_ps(h_k, x_k);
407       s_inst_128_4 = _mm_mul_ps(h_k_4, x_k_4);
408       s_acum += s_p[0] + s_p[1] + s_p[2] + s_p[3];
409       e_p[0] = s_acum - y[i];
410       s_acum += s_4_p[0] + s_4_p[1] + s_4_p[2] + s_4_p[3];
411       e_p[1] = s_acum - y[i];
412       a_p[0] += e_p[0] * e_p[0];
413       a_p[1] += e_p[1] * e_p[1];
414     }
415     // Combine the accumulated vector and scalar values.
416     x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4);
417     float* v = reinterpret_cast<float*>(&x2_sum_128);
418     x2_sum += v[0] + v[1] + v[2] + v[3];
419     // Compute the matched filter error.
420     float e = y[i] - s_acum;
421     const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
422     (*error_sum) += e * e;
423     // Update the matched filter estimate in an NLMS manner.
424     if (x2_sum > x2_sum_threshold && !saturation) {
425       RTC_DCHECK_LT(0.f, x2_sum);
426       const float alpha = smoothing * e / x2_sum;
427       const __m128 alpha_128 = _mm_set1_ps(alpha);
428       // filter = filter + smoothing * (y - filter * x) * x / x * x.
429       float* h_p = &h[0];
430       const float* x_p =
431           chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
432       // Perform 128 bit vector operations.
433       const int limit_by_4 = h_size >> 2;
434       for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
435         // Load the data into 128 bit vectors.
436         __m128 h_k = _mm_loadu_ps(h_p);
437         const __m128 x_k = _mm_loadu_ps(x_p);
438         // Compute h = h + alpha * x.
439         const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
440         h_k = _mm_add_ps(h_k, alpha_x);
441         // Store the result.
442         _mm_storeu_ps(h_p, h_k);
443       }
444       *filters_updated = true;
445     }
446     x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
447   }
448 }
449 
MatchedFilterCore_SSE2(size_t x_start_index,float x2_sum_threshold,float smoothing,rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float> h,bool * filters_updated,float * error_sum,bool compute_accumulated_error,rtc::ArrayView<float> accumulated_error,rtc::ArrayView<float> scratch_memory)450 void MatchedFilterCore_SSE2(size_t x_start_index,
451                             float x2_sum_threshold,
452                             float smoothing,
453                             rtc::ArrayView<const float> x,
454                             rtc::ArrayView<const float> y,
455                             rtc::ArrayView<float> h,
456                             bool* filters_updated,
457                             float* error_sum,
458                             bool compute_accumulated_error,
459                             rtc::ArrayView<float> accumulated_error,
460                             rtc::ArrayView<float> scratch_memory) {
461   if (compute_accumulated_error) {
462     return MatchedFilterCore_AccumulatedError_SSE2(
463         x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated,
464         error_sum, accumulated_error, scratch_memory);
465   }
466   const int h_size = static_cast<int>(h.size());
467   const int x_size = static_cast<int>(x.size());
468   RTC_DCHECK_EQ(0, h_size % 4);
469   // Process for all samples in the sub-block.
470   for (size_t i = 0; i < y.size(); ++i) {
471     // Apply the matched filter as filter * x, and compute x * x.
472     RTC_DCHECK_GT(x_size, x_start_index);
473     const float* x_p = &x[x_start_index];
474     const float* h_p = &h[0];
475     // Initialize values for the accumulation.
476     __m128 s_128 = _mm_set1_ps(0);
477     __m128 s_128_4 = _mm_set1_ps(0);
478     __m128 x2_sum_128 = _mm_set1_ps(0);
479     __m128 x2_sum_128_4 = _mm_set1_ps(0);
480     float x2_sum = 0.f;
481     float s = 0;
482     // Compute loop chunk sizes until, and after, the wraparound of the circular
483     // buffer for x.
484     const int chunk1 =
485         std::min(h_size, static_cast<int>(x_size - x_start_index));
486     // Perform the loop in two chunks.
487     const int chunk2 = h_size - chunk1;
488     for (int limit : {chunk1, chunk2}) {
489       // Perform 128 bit vector operations.
490       const int limit_by_8 = limit >> 3;
491       for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) {
492         // Load the data into 128 bit vectors.
493         const __m128 x_k = _mm_loadu_ps(x_p);
494         const __m128 h_k = _mm_loadu_ps(h_p);
495         const __m128 x_k_4 = _mm_loadu_ps(x_p + 4);
496         const __m128 h_k_4 = _mm_loadu_ps(h_p + 4);
497         const __m128 xx = _mm_mul_ps(x_k, x_k);
498         const __m128 xx_4 = _mm_mul_ps(x_k_4, x_k_4);
499         // Compute and accumulate x * x and h * x.
500         x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
501         x2_sum_128_4 = _mm_add_ps(x2_sum_128_4, xx_4);
502         const __m128 hx = _mm_mul_ps(h_k, x_k);
503         const __m128 hx_4 = _mm_mul_ps(h_k_4, x_k_4);
504         s_128 = _mm_add_ps(s_128, hx);
505         s_128_4 = _mm_add_ps(s_128_4, hx_4);
506       }
507       // Perform non-vector operations for any remaining items.
508       for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) {
509         const float x_k = *x_p;
510         x2_sum += x_k * x_k;
511         s += *h_p * x_k;
512       }
513       x_p = &x[0];
514     }
515     // Combine the accumulated vector and scalar values.
516     x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4);
517     float* v = reinterpret_cast<float*>(&x2_sum_128);
518     x2_sum += v[0] + v[1] + v[2] + v[3];
519     s_128 = _mm_add_ps(s_128, s_128_4);
520     v = reinterpret_cast<float*>(&s_128);
521     s += v[0] + v[1] + v[2] + v[3];
522     // Compute the matched filter error.
523     float e = y[i] - s;
524     const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
525     (*error_sum) += e * e;
526     // Update the matched filter estimate in an NLMS manner.
527     if (x2_sum > x2_sum_threshold && !saturation) {
528       RTC_DCHECK_LT(0.f, x2_sum);
529       const float alpha = smoothing * e / x2_sum;
530       const __m128 alpha_128 = _mm_set1_ps(alpha);
531       // filter = filter + smoothing * (y - filter * x) * x / x * x.
532       float* h_p = &h[0];
533       x_p = &x[x_start_index];
534       // Perform the loop in two chunks.
535       for (int limit : {chunk1, chunk2}) {
536         // Perform 128 bit vector operations.
537         const int limit_by_4 = limit >> 2;
538         for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
539           // Load the data into 128 bit vectors.
540           __m128 h_k = _mm_loadu_ps(h_p);
541           const __m128 x_k = _mm_loadu_ps(x_p);
542 
543           // Compute h = h + alpha * x.
544           const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
545           h_k = _mm_add_ps(h_k, alpha_x);
546           // Store the result.
547           _mm_storeu_ps(h_p, h_k);
548         }
549         // Perform non-vector operations for any remaining items.
550         for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
551           *h_p += alpha * *x_p;
552         }
553         x_p = &x[0];
554       }
555       *filters_updated = true;
556     }
557     x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
558   }
559 }
560 #endif
561 
MatchedFilterCore(size_t x_start_index,float x2_sum_threshold,float smoothing,rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float> h,bool * filters_updated,float * error_sum,bool compute_accumulated_error,rtc::ArrayView<float> accumulated_error)562 void MatchedFilterCore(size_t x_start_index,
563                        float x2_sum_threshold,
564                        float smoothing,
565                        rtc::ArrayView<const float> x,
566                        rtc::ArrayView<const float> y,
567                        rtc::ArrayView<float> h,
568                        bool* filters_updated,
569                        float* error_sum,
570                        bool compute_accumulated_error,
571                        rtc::ArrayView<float> accumulated_error) {
572   if (compute_accumulated_error) {
573     std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
574   }
575 
576   // Process for all samples in the sub-block.
577   for (size_t i = 0; i < y.size(); ++i) {
578     // Apply the matched filter as filter * x, and compute x * x.
579     float x2_sum = 0.f;
580     float s = 0;
581     size_t x_index = x_start_index;
582     if (compute_accumulated_error) {
583       for (size_t k = 0; k < h.size(); ++k) {
584         x2_sum += x[x_index] * x[x_index];
585         s += h[k] * x[x_index];
586         x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
587         if ((k + 1 & 0b11) == 0) {
588           int idx = k >> 2;
589           accumulated_error[idx] += (y[i] - s) * (y[i] - s);
590         }
591       }
592     } else {
593       for (size_t k = 0; k < h.size(); ++k) {
594         x2_sum += x[x_index] * x[x_index];
595         s += h[k] * x[x_index];
596         x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
597       }
598     }
599 
600     // Compute the matched filter error.
601     float e = y[i] - s;
602     const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
603     (*error_sum) += e * e;
604 
605     // Update the matched filter estimate in an NLMS manner.
606     if (x2_sum > x2_sum_threshold && !saturation) {
607       RTC_DCHECK_LT(0.f, x2_sum);
608       const float alpha = smoothing * e / x2_sum;
609 
610       // filter = filter + smoothing * (y - filter * x) * x / x * x.
611       size_t x_index = x_start_index;
612       for (size_t k = 0; k < h.size(); ++k) {
613         h[k] += alpha * x[x_index];
614         x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
615       }
616       *filters_updated = true;
617     }
618 
619     x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1;
620   }
621 }
622 
MaxSquarePeakIndex(rtc::ArrayView<const float> h)623 size_t MaxSquarePeakIndex(rtc::ArrayView<const float> h) {
624   if (h.size() < 2) {
625     return 0;
626   }
627   float max_element1 = h[0] * h[0];
628   float max_element2 = h[1] * h[1];
629   size_t lag_estimate1 = 0;
630   size_t lag_estimate2 = 1;
631   const size_t last_index = h.size() - 1;
632   // Keeping track of even & odd max elements separately typically allows the
633   // compiler to produce more efficient code.
634   for (size_t k = 2; k < last_index; k += 2) {
635     float element1 = h[k] * h[k];
636     float element2 = h[k + 1] * h[k + 1];
637     if (element1 > max_element1) {
638       max_element1 = element1;
639       lag_estimate1 = k;
640     }
641     if (element2 > max_element2) {
642       max_element2 = element2;
643       lag_estimate2 = k + 1;
644     }
645   }
646   if (max_element2 > max_element1) {
647     max_element1 = max_element2;
648     lag_estimate1 = lag_estimate2;
649   }
650   // In case of odd h size, we have not yet checked the last element.
651   float last_element = h[last_index] * h[last_index];
652   if (last_element > max_element1) {
653     return last_index;
654   }
655   return lag_estimate1;
656 }
657 
658 }  // namespace aec3
659 
MatchedFilter(ApmDataDumper * data_dumper,Aec3Optimization optimization,size_t sub_block_size,size_t window_size_sub_blocks,int num_matched_filters,size_t alignment_shift_sub_blocks,float excitation_limit,float smoothing_fast,float smoothing_slow,float matching_filter_threshold,bool detect_pre_echo)660 MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
661                              Aec3Optimization optimization,
662                              size_t sub_block_size,
663                              size_t window_size_sub_blocks,
664                              int num_matched_filters,
665                              size_t alignment_shift_sub_blocks,
666                              float excitation_limit,
667                              float smoothing_fast,
668                              float smoothing_slow,
669                              float matching_filter_threshold,
670                              bool detect_pre_echo)
671     : data_dumper_(data_dumper),
672       optimization_(optimization),
673       sub_block_size_(sub_block_size),
674       filter_intra_lag_shift_(alignment_shift_sub_blocks * sub_block_size_),
675       filters_(
676           num_matched_filters,
677           std::vector<float>(window_size_sub_blocks * sub_block_size_, 0.f)),
678       filters_offsets_(num_matched_filters, 0),
679       excitation_limit_(excitation_limit),
680       smoothing_fast_(smoothing_fast),
681       smoothing_slow_(smoothing_slow),
682       matching_filter_threshold_(matching_filter_threshold),
683       detect_pre_echo_(detect_pre_echo),
684       pre_echo_config_(FetchPreEchoConfiguration()) {
685   RTC_DCHECK(data_dumper);
686   RTC_DCHECK_LT(0, window_size_sub_blocks);
687   RTC_DCHECK((kBlockSize % sub_block_size) == 0);
688   RTC_DCHECK((sub_block_size % 4) == 0);
689   static_assert(kAccumulatedErrorSubSampleRate == 4);
690   if (detect_pre_echo_) {
691     accumulated_error_ = std::vector<std::vector<float>>(
692         num_matched_filters,
693         std::vector<float>(window_size_sub_blocks * sub_block_size_ /
694                                kAccumulatedErrorSubSampleRate,
695                            1.0f));
696 
697     instantaneous_accumulated_error_ =
698         std::vector<float>(window_size_sub_blocks * sub_block_size_ /
699                                kAccumulatedErrorSubSampleRate,
700                            0.0f);
701     scratch_memory_ =
702         std::vector<float>(window_size_sub_blocks * sub_block_size_);
703   }
704 }
705 
706 MatchedFilter::~MatchedFilter() = default;
707 
Reset()708 void MatchedFilter::Reset() {
709   for (auto& f : filters_) {
710     std::fill(f.begin(), f.end(), 0.f);
711   }
712 
713   for (auto& e : accumulated_error_) {
714     std::fill(e.begin(), e.end(), 1.0f);
715   }
716 
717   winner_lag_ = absl::nullopt;
718   reported_lag_estimate_ = absl::nullopt;
719 }
720 
Update(const DownsampledRenderBuffer & render_buffer,rtc::ArrayView<const float> capture,bool use_slow_smoothing)721 void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
722                            rtc::ArrayView<const float> capture,
723                            bool use_slow_smoothing) {
724   RTC_DCHECK_EQ(sub_block_size_, capture.size());
725   auto& y = capture;
726 
727   const float smoothing =
728       use_slow_smoothing ? smoothing_slow_ : smoothing_fast_;
729 
730   const float x2_sum_threshold =
731       filters_[0].size() * excitation_limit_ * excitation_limit_;
732 
733   // Compute anchor for the matched filter error.
734   float error_sum_anchor = 0.0f;
735   for (size_t k = 0; k < y.size(); ++k) {
736     error_sum_anchor += y[k] * y[k];
737   }
738 
739   // Apply all matched filters.
740   float winner_error_sum = error_sum_anchor;
741   winner_lag_ = absl::nullopt;
742   reported_lag_estimate_ = absl::nullopt;
743   size_t alignment_shift = 0;
744   absl::optional<size_t> previous_lag_estimate;
745   const int num_filters = static_cast<int>(filters_.size());
746   int winner_index = -1;
747   for (int n = 0; n < num_filters; ++n) {
748     float error_sum = 0.f;
749     bool filters_updated = false;
750     const bool compute_pre_echo =
751         detect_pre_echo_ && n == last_detected_best_lag_filter_;
752 
753     size_t x_start_index =
754         (render_buffer.read + alignment_shift + sub_block_size_ - 1) %
755         render_buffer.buffer.size();
756 
757     switch (optimization_) {
758 #if defined(WEBRTC_ARCH_X86_FAMILY)
759       case Aec3Optimization::kSse2:
760         aec3::MatchedFilterCore_SSE2(
761             x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y,
762             filters_[n], &filters_updated, &error_sum, compute_pre_echo,
763             instantaneous_accumulated_error_, scratch_memory_);
764         break;
765       case Aec3Optimization::kAvx2:
766         aec3::MatchedFilterCore_AVX2(
767             x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y,
768             filters_[n], &filters_updated, &error_sum, compute_pre_echo,
769             instantaneous_accumulated_error_, scratch_memory_);
770         break;
771 #endif
772 #if defined(WEBRTC_HAS_NEON)
773       case Aec3Optimization::kNeon:
774         aec3::MatchedFilterCore_NEON(
775             x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y,
776             filters_[n], &filters_updated, &error_sum, compute_pre_echo,
777             instantaneous_accumulated_error_, scratch_memory_);
778         break;
779 #endif
780       default:
781         aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, smoothing,
782                                 render_buffer.buffer, y, filters_[n],
783                                 &filters_updated, &error_sum, compute_pre_echo,
784                                 instantaneous_accumulated_error_);
785     }
786 
787     // Estimate the lag in the matched filter as the distance to the portion in
788     // the filter that contributes the most to the matched filter output. This
789     // is detected as the peak of the matched filter.
790     const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]);
791     const bool reliable =
792         lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) &&
793         error_sum < matching_filter_threshold_ * error_sum_anchor;
794 
795     // Find the best estimate
796     const size_t lag = lag_estimate + alignment_shift;
797     if (filters_updated && reliable && error_sum < winner_error_sum) {
798       winner_error_sum = error_sum;
799       winner_index = n;
800       // In case that 2 matched filters return the same winner candidate
801       // (overlap region), the one with the smaller index is chosen in order
802       // to search for pre-echoes.
803       if (previous_lag_estimate && previous_lag_estimate == lag) {
804         winner_lag_ = previous_lag_estimate;
805         winner_index = n - 1;
806       } else {
807         winner_lag_ = lag;
808       }
809     }
810     previous_lag_estimate = lag;
811     alignment_shift += filter_intra_lag_shift_;
812   }
813 
814   if (winner_index != -1) {
815     RTC_DCHECK(winner_lag_.has_value());
816     reported_lag_estimate_ =
817         LagEstimate(winner_lag_.value(), /*pre_echo_lag=*/winner_lag_.value());
818     if (detect_pre_echo_ && last_detected_best_lag_filter_ == winner_index) {
819       if (error_sum_anchor > 30.0f * 30.0f * y.size()) {
820         UpdateAccumulatedError(instantaneous_accumulated_error_,
821                                accumulated_error_[winner_index],
822                                1.0f / error_sum_anchor);
823       }
824       reported_lag_estimate_->pre_echo_lag = ComputePreEchoLag(
825           pre_echo_config_, accumulated_error_[winner_index],
826           winner_lag_.value(),
827           winner_index * filter_intra_lag_shift_ /*alignment_shift_winner*/);
828     }
829     last_detected_best_lag_filter_ = winner_index;
830   }
831   if (ApmDataDumper::IsAvailable()) {
832     Dump();
833   }
834 }
835 
LogFilterProperties(int sample_rate_hz,size_t shift,size_t downsampling_factor) const836 void MatchedFilter::LogFilterProperties(int sample_rate_hz,
837                                         size_t shift,
838                                         size_t downsampling_factor) const {
839   size_t alignment_shift = 0;
840   constexpr int kFsBy1000 = 16;
841   for (size_t k = 0; k < filters_.size(); ++k) {
842     int start = static_cast<int>(alignment_shift * downsampling_factor);
843     int end = static_cast<int>((alignment_shift + filters_[k].size()) *
844                                downsampling_factor);
845     RTC_LOG(LS_VERBOSE) << "Filter " << k << ": start: "
846                         << (start - static_cast<int>(shift)) / kFsBy1000
847                         << " ms, end: "
848                         << (end - static_cast<int>(shift)) / kFsBy1000
849                         << " ms.";
850     alignment_shift += filter_intra_lag_shift_;
851   }
852 }
853 
Dump()854 void MatchedFilter::Dump() {
855   for (size_t n = 0; n < filters_.size(); ++n) {
856     const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]);
857     std::string dumper_filter = "aec3_correlator_" + std::to_string(n) + "_h";
858     data_dumper_->DumpRaw(dumper_filter.c_str(), filters_[n]);
859     std::string dumper_lag = "aec3_correlator_lag_" + std::to_string(n);
860     data_dumper_->DumpRaw(dumper_lag.c_str(),
861                           lag_estimate + n * filter_intra_lag_shift_);
862     if (detect_pre_echo_) {
863       std::string dumper_error =
864           "aec3_correlator_error_" + std::to_string(n) + "_h";
865       data_dumper_->DumpRaw(dumper_error.c_str(), accumulated_error_[n]);
866 
867       size_t pre_echo_lag =
868           ComputePreEchoLag(pre_echo_config_, accumulated_error_[n],
869                             lag_estimate + n * filter_intra_lag_shift_,
870                             n * filter_intra_lag_shift_);
871       std::string dumper_pre_lag =
872           "aec3_correlator_pre_echo_lag_" + std::to_string(n);
873       data_dumper_->DumpRaw(dumper_pre_lag.c_str(), pre_echo_lag);
874     }
875   }
876 }
877 
878 }  // namespace webrtc
879