1 /* 2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #ifndef MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_ 12 #define MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_ 13 14 #include <stddef.h> 15 16 #include <vector> 17 18 #include "absl/types/optional.h" 19 #include "api/array_view.h" 20 #include "modules/audio_processing/aec3/aec3_common.h" 21 #include "rtc_base/gtest_prod_util.h" 22 #include "rtc_base/system/arch.h" 23 24 namespace webrtc { 25 26 class ApmDataDumper; 27 struct DownsampledRenderBuffer; 28 29 namespace aec3 { 30 31 #if defined(WEBRTC_HAS_NEON) 32 33 // Filter core for the matched filter that is optimized for NEON. 34 void MatchedFilterCore_NEON(size_t x_start_index, 35 float x2_sum_threshold, 36 float smoothing, 37 rtc::ArrayView<const float> x, 38 rtc::ArrayView<const float> y, 39 rtc::ArrayView<float> h, 40 bool* filters_updated, 41 float* error_sum, 42 bool compute_accumulation_error, 43 rtc::ArrayView<float> accumulated_error, 44 rtc::ArrayView<float> scratch_memory); 45 46 #endif 47 48 #if defined(WEBRTC_ARCH_X86_FAMILY) 49 50 // Filter core for the matched filter that is optimized for SSE2. 51 void MatchedFilterCore_SSE2(size_t x_start_index, 52 float x2_sum_threshold, 53 float smoothing, 54 rtc::ArrayView<const float> x, 55 rtc::ArrayView<const float> y, 56 rtc::ArrayView<float> h, 57 bool* filters_updated, 58 float* error_sum, 59 bool compute_accumulated_error, 60 rtc::ArrayView<float> accumulated_error, 61 rtc::ArrayView<float> scratch_memory); 62 63 // Filter core for the matched filter that is optimized for AVX2. 64 void MatchedFilterCore_AVX2(size_t x_start_index, 65 float x2_sum_threshold, 66 float smoothing, 67 rtc::ArrayView<const float> x, 68 rtc::ArrayView<const float> y, 69 rtc::ArrayView<float> h, 70 bool* filters_updated, 71 float* error_sum, 72 bool compute_accumulated_error, 73 rtc::ArrayView<float> accumulated_error, 74 rtc::ArrayView<float> scratch_memory); 75 76 #endif 77 78 // Filter core for the matched filter. 79 void MatchedFilterCore(size_t x_start_index, 80 float x2_sum_threshold, 81 float smoothing, 82 rtc::ArrayView<const float> x, 83 rtc::ArrayView<const float> y, 84 rtc::ArrayView<float> h, 85 bool* filters_updated, 86 float* error_sum, 87 bool compute_accumulation_error, 88 rtc::ArrayView<float> accumulated_error); 89 90 // Find largest peak of squared values in array. 91 size_t MaxSquarePeakIndex(rtc::ArrayView<const float> h); 92 93 } // namespace aec3 94 95 // Produces recursively updated cross-correlation estimates for several signal 96 // shifts where the intra-shift spacing is uniform. 97 class MatchedFilter { 98 public: 99 // Stores properties for the lag estimate corresponding to a particular signal 100 // shift. 101 struct LagEstimate { 102 LagEstimate() = default; LagEstimateLagEstimate103 LagEstimate(size_t lag, size_t pre_echo_lag) 104 : lag(lag), pre_echo_lag(pre_echo_lag) {} 105 size_t lag = 0; 106 size_t pre_echo_lag = 0; 107 }; 108 109 struct PreEchoConfiguration { 110 const float threshold; 111 const int mode; 112 }; 113 114 MatchedFilter(ApmDataDumper* data_dumper, 115 Aec3Optimization optimization, 116 size_t sub_block_size, 117 size_t window_size_sub_blocks, 118 int num_matched_filters, 119 size_t alignment_shift_sub_blocks, 120 float excitation_limit, 121 float smoothing_fast, 122 float smoothing_slow, 123 float matching_filter_threshold, 124 bool detect_pre_echo); 125 126 MatchedFilter() = delete; 127 MatchedFilter(const MatchedFilter&) = delete; 128 MatchedFilter& operator=(const MatchedFilter&) = delete; 129 130 ~MatchedFilter(); 131 132 // Updates the correlation with the values in the capture buffer. 133 void Update(const DownsampledRenderBuffer& render_buffer, 134 rtc::ArrayView<const float> capture, 135 bool use_slow_smoothing); 136 137 // Resets the matched filter. 138 void Reset(); 139 140 // Returns the current lag estimates. GetBestLagEstimate()141 absl::optional<const MatchedFilter::LagEstimate> GetBestLagEstimate() const { 142 return reported_lag_estimate_; 143 } 144 145 // Returns the maximum filter lag. GetMaxFilterLag()146 size_t GetMaxFilterLag() const { 147 return filters_.size() * filter_intra_lag_shift_ + filters_[0].size(); 148 } 149 150 // Log matched filter properties. 151 void LogFilterProperties(int sample_rate_hz, 152 size_t shift, 153 size_t downsampling_factor) const; 154 155 private: 156 FRIEND_TEST_ALL_PREFIXES(MatchedFilterFieldTrialTest, 157 PreEchoConfigurationTest); 158 FRIEND_TEST_ALL_PREFIXES(MatchedFilterFieldTrialTest, 159 WrongPreEchoConfigurationTest); 160 161 // Only for testing. Gets the pre echo detection configuration. GetPreEchoConfiguration()162 const PreEchoConfiguration& GetPreEchoConfiguration() const { 163 return pre_echo_config_; 164 } 165 void Dump(); 166 167 ApmDataDumper* const data_dumper_; 168 const Aec3Optimization optimization_; 169 const size_t sub_block_size_; 170 const size_t filter_intra_lag_shift_; 171 std::vector<std::vector<float>> filters_; 172 std::vector<std::vector<float>> accumulated_error_; 173 std::vector<float> instantaneous_accumulated_error_; 174 std::vector<float> scratch_memory_; 175 absl::optional<MatchedFilter::LagEstimate> reported_lag_estimate_; 176 absl::optional<size_t> winner_lag_; 177 int last_detected_best_lag_filter_ = -1; 178 std::vector<size_t> filters_offsets_; 179 const float excitation_limit_; 180 const float smoothing_fast_; 181 const float smoothing_slow_; 182 const float matching_filter_threshold_; 183 const bool detect_pre_echo_; 184 const PreEchoConfiguration pre_echo_config_; 185 }; 186 187 } // namespace webrtc 188 189 #endif // MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_ 190