xref: /aosp_15_r20/external/webrtc/modules/audio_processing/aec3/matched_filter.h (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #ifndef MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_
12 #define MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_
13 
14 #include <stddef.h>
15 
16 #include <vector>
17 
18 #include "absl/types/optional.h"
19 #include "api/array_view.h"
20 #include "modules/audio_processing/aec3/aec3_common.h"
21 #include "rtc_base/gtest_prod_util.h"
22 #include "rtc_base/system/arch.h"
23 
24 namespace webrtc {
25 
26 class ApmDataDumper;
27 struct DownsampledRenderBuffer;
28 
29 namespace aec3 {
30 
31 #if defined(WEBRTC_HAS_NEON)
32 
33 // Filter core for the matched filter that is optimized for NEON.
   // Only declared when WEBRTC_HAS_NEON is defined. The parameter list has the
   // same shape as the SSE2 and AVX2 variants in this header; compared with the
   // generic MatchedFilterCore it takes one extra trailing view,
   // `scratch_memory`.
   // `x` and `y` are read-only float views; `h`, `accumulated_error` and
   // `scratch_memory` are mutable float views. `filters_updated` and
   // `error_sum` are non-const pointers - presumably outputs written by the
   // core; confirm against the definition.
   // NOTE(review): the flag is spelled `compute_accumulation_error` here and in
   // MatchedFilterCore, but `compute_accumulated_error` in the SSE2/AVX2
   // declarations. Callers are unaffected, but the spelling should be unified.
34 void MatchedFilterCore_NEON(size_t x_start_index,
35                             float x2_sum_threshold,
36                             float smoothing,
37                             rtc::ArrayView<const float> x,
38                             rtc::ArrayView<const float> y,
39                             rtc::ArrayView<float> h,
40                             bool* filters_updated,
41                             float* error_sum,
42                             bool compute_accumulation_error,
43                             rtc::ArrayView<float> accumulated_error,
44                             rtc::ArrayView<float> scratch_memory);
45 
46 #endif
47 
48 #if defined(WEBRTC_ARCH_X86_FAMILY)
49 
50 // Filter core for the matched filter that is optimized for SSE2.
   // Only declared when WEBRTC_ARCH_X86_FAMILY is defined. The parameter list
   // matches MatchedFilterCore_AVX2 exactly; compared with the generic
   // MatchedFilterCore it takes one extra trailing view, `scratch_memory`.
   // `x` and `y` are read-only float views; `h`, `accumulated_error` and
   // `scratch_memory` are mutable float views. `filters_updated` and
   // `error_sum` are non-const pointers - presumably outputs written by the
   // core; confirm against the definition.
51 void MatchedFilterCore_SSE2(size_t x_start_index,
52                             float x2_sum_threshold,
53                             float smoothing,
54                             rtc::ArrayView<const float> x,
55                             rtc::ArrayView<const float> y,
56                             rtc::ArrayView<float> h,
57                             bool* filters_updated,
58                             float* error_sum,
59                             bool compute_accumulated_error,
60                             rtc::ArrayView<float> accumulated_error,
61                             rtc::ArrayView<float> scratch_memory);
62 
63 // Filter core for the matched filter that is optimized for AVX2.
   // Only declared when WEBRTC_ARCH_X86_FAMILY is defined. The parameter list
   // matches MatchedFilterCore_SSE2 exactly; compared with the generic
   // MatchedFilterCore it takes one extra trailing view, `scratch_memory`.
   // `x` and `y` are read-only float views; `h`, `accumulated_error` and
   // `scratch_memory` are mutable float views. `filters_updated` and
   // `error_sum` are non-const pointers - presumably outputs written by the
   // core; confirm against the definition.
64 void MatchedFilterCore_AVX2(size_t x_start_index,
65                             float x2_sum_threshold,
66                             float smoothing,
67                             rtc::ArrayView<const float> x,
68                             rtc::ArrayView<const float> y,
69                             rtc::ArrayView<float> h,
70                             bool* filters_updated,
71                             float* error_sum,
72                             bool compute_accumulated_error,
73                             rtc::ArrayView<float> accumulated_error,
74                             rtc::ArrayView<float> scratch_memory);
75 
76 #endif
77 
78 // Filter core for the matched filter.
   // Declared unconditionally (outside the NEON / x86 preprocessor guards), so
   // it is available on every build. Unlike the _NEON/_SSE2/_AVX2 variants it
   // takes no `scratch_memory` view.
   // `x` and `y` are read-only float views; `h` and `accumulated_error` are
   // mutable float views. `filters_updated` and `error_sum` are non-const
   // pointers - presumably outputs written by the core; confirm against the
   // definition.
   // NOTE(review): the flag is spelled `compute_accumulation_error` here and in
   // MatchedFilterCore_NEON, but `compute_accumulated_error` in the SSE2/AVX2
   // declarations. Callers are unaffected, but the spelling should be unified.
79 void MatchedFilterCore(size_t x_start_index,
80                        float x2_sum_threshold,
81                        float smoothing,
82                        rtc::ArrayView<const float> x,
83                        rtc::ArrayView<const float> y,
84                        rtc::ArrayView<float> h,
85                        bool* filters_updated,
86                        float* error_sum,
87                        bool compute_accumulation_error,
88                        rtc::ArrayView<float> accumulated_error);
89 
90 // Finds the largest peak of the squared values in the array `h` and returns
   // its position as a size_t index.
   // NOTE(review): tie-breaking and the result for an empty view are not
   // visible from this declaration - check the definition before relying on
   // either.
91 size_t MaxSquarePeakIndex(rtc::ArrayView<const float> h);
92 
93 }  // namespace aec3
94 
95 // Produces recursively updated cross-correlation estimates for several signal
96 // shifts where the intra-shift spacing is uniform.
97 class MatchedFilter {
98  public:
99   // Stores properties for the lag estimate corresponding to a particular signal
100   // shift.
101   struct LagEstimate {
102     LagEstimate() = default;
LagEstimateLagEstimate103     LagEstimate(size_t lag, size_t pre_echo_lag)
104         : lag(lag), pre_echo_lag(pre_echo_lag) {}
105     size_t lag = 0;
106     size_t pre_echo_lag = 0;
107   };
108 
109   struct PreEchoConfiguration {
110     const float threshold;
111     const int mode;
112   };
113 
114   MatchedFilter(ApmDataDumper* data_dumper,
115                 Aec3Optimization optimization,
116                 size_t sub_block_size,
117                 size_t window_size_sub_blocks,
118                 int num_matched_filters,
119                 size_t alignment_shift_sub_blocks,
120                 float excitation_limit,
121                 float smoothing_fast,
122                 float smoothing_slow,
123                 float matching_filter_threshold,
124                 bool detect_pre_echo);
125 
126   MatchedFilter() = delete;
127   MatchedFilter(const MatchedFilter&) = delete;
128   MatchedFilter& operator=(const MatchedFilter&) = delete;
129 
130   ~MatchedFilter();
131 
132   // Updates the correlation with the values in the capture buffer.
133   void Update(const DownsampledRenderBuffer& render_buffer,
134               rtc::ArrayView<const float> capture,
135               bool use_slow_smoothing);
136 
137   // Resets the matched filter.
138   void Reset();
139 
140   // Returns the current lag estimates.
GetBestLagEstimate()141   absl::optional<const MatchedFilter::LagEstimate> GetBestLagEstimate() const {
142     return reported_lag_estimate_;
143   }
144 
145   // Returns the maximum filter lag.
GetMaxFilterLag()146   size_t GetMaxFilterLag() const {
147     return filters_.size() * filter_intra_lag_shift_ + filters_[0].size();
148   }
149 
150   // Log matched filter properties.
151   void LogFilterProperties(int sample_rate_hz,
152                            size_t shift,
153                            size_t downsampling_factor) const;
154 
155  private:
156   FRIEND_TEST_ALL_PREFIXES(MatchedFilterFieldTrialTest,
157                            PreEchoConfigurationTest);
158   FRIEND_TEST_ALL_PREFIXES(MatchedFilterFieldTrialTest,
159                            WrongPreEchoConfigurationTest);
160 
161   // Only for testing. Gets the pre echo detection configuration.
GetPreEchoConfiguration()162   const PreEchoConfiguration& GetPreEchoConfiguration() const {
163     return pre_echo_config_;
164   }
165   void Dump();
166 
167   ApmDataDumper* const data_dumper_;
168   const Aec3Optimization optimization_;
169   const size_t sub_block_size_;
170   const size_t filter_intra_lag_shift_;
171   std::vector<std::vector<float>> filters_;
172   std::vector<std::vector<float>> accumulated_error_;
173   std::vector<float> instantaneous_accumulated_error_;
174   std::vector<float> scratch_memory_;
175   absl::optional<MatchedFilter::LagEstimate> reported_lag_estimate_;
176   absl::optional<size_t> winner_lag_;
177   int last_detected_best_lag_filter_ = -1;
178   std::vector<size_t> filters_offsets_;
179   const float excitation_limit_;
180   const float smoothing_fast_;
181   const float smoothing_slow_;
182   const float matching_filter_threshold_;
183   const bool detect_pre_echo_;
184   const PreEchoConfiguration pre_echo_config_;
185 };
186 
187 }  // namespace webrtc
188 
189 #endif  // MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_
190