xref: /aosp_15_r20/external/webrtc/modules/audio_processing/aec3/transparent_mode.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/transparent_mode.h"
12 
13 #include "rtc_base/checks.h"
14 #include "rtc_base/logging.h"
15 #include "system_wrappers/include/field_trial.h"
16 
17 namespace webrtc {
18 namespace {
19 
20 constexpr size_t kBlocksSinceConvergencedFilterInit = 10000;
21 constexpr size_t kBlocksSinceConsistentEstimateInit = 10000;
22 
DeactivateTransparentMode()23 bool DeactivateTransparentMode() {
24   return field_trial::IsEnabled("WebRTC-Aec3TransparentModeKillSwitch");
25 }
26 
ActivateTransparentModeHmm()27 bool ActivateTransparentModeHmm() {
28   return field_trial::IsEnabled("WebRTC-Aec3TransparentModeHmm");
29 }
30 
31 }  // namespace
32 
33 // Classifier that toggles transparent mode which reduces echo suppression when
34 // headsets are used.
35 class TransparentModeImpl : public TransparentMode {
36  public:
Active() const37   bool Active() const override { return transparency_activated_; }
38 
Reset()39   void Reset() override {
40     // Determines if transparent mode is used.
41     transparency_activated_ = false;
42 
43     // The estimated probability of being transparent mode.
44     prob_transparent_state_ = 0.f;
45   }
46 
Update(int filter_delay_blocks,bool any_filter_consistent,bool any_filter_converged,bool any_coarse_filter_converged,bool all_filters_diverged,bool active_render,bool saturated_capture)47   void Update(int filter_delay_blocks,
48               bool any_filter_consistent,
49               bool any_filter_converged,
50               bool any_coarse_filter_converged,
51               bool all_filters_diverged,
52               bool active_render,
53               bool saturated_capture) override {
54     // The classifier is implemented as a Hidden Markov Model (HMM) with two
55     // hidden states: "normal" and "transparent". The estimated probabilities of
56     // the two states are updated by observing filter convergence during active
57     // render. The filters are less likely to be reported as converged when
58     // there is no echo present in the microphone signal.
59 
60     // The constants have been obtained by observing active_render and
61     // any_coarse_filter_converged under varying call scenarios. They
62     // have further been hand tuned to prefer normal state during uncertain
63     // regions (to avoid echo leaks).
64 
65     // The model is only updated during active render.
66     if (!active_render)
67       return;
68 
69     // Probability of switching from one state to the other.
70     constexpr float kSwitch = 0.000001f;
71 
72     // Probability of observing converged filters in states "normal" and
73     // "transparent" during active render.
74     constexpr float kConvergedNormal = 0.01f;
75     constexpr float kConvergedTransparent = 0.001f;
76 
77     // Probability of transitioning to transparent state from normal state and
78     // transparent state respectively.
79     constexpr float kA[2] = {kSwitch, 1.f - kSwitch};
80 
81     // Probability of the two observations (converged filter or not converged
82     // filter) in normal state and transparent state respectively.
83     constexpr float kB[2][2] = {
84         {1.f - kConvergedNormal, kConvergedNormal},
85         {1.f - kConvergedTransparent, kConvergedTransparent}};
86 
87     // Probability of the two states before the update.
88     const float prob_transparent = prob_transparent_state_;
89     const float prob_normal = 1.f - prob_transparent;
90 
91     // Probability of transitioning to transparent state.
92     const float prob_transition_transparent =
93         prob_normal * kA[0] + prob_transparent * kA[1];
94     const float prob_transition_normal = 1.f - prob_transition_transparent;
95 
96     // Observed output.
97     const int out = static_cast<int>(any_coarse_filter_converged);
98 
99     // Joint probabilites of the observed output and respective states.
100     const float prob_joint_normal = prob_transition_normal * kB[0][out];
101     const float prob_joint_transparent =
102         prob_transition_transparent * kB[1][out];
103 
104     // Conditional probability of transparent state and the observed output.
105     RTC_DCHECK_GT(prob_joint_normal + prob_joint_transparent, 0.f);
106     prob_transparent_state_ =
107         prob_joint_transparent / (prob_joint_normal + prob_joint_transparent);
108 
109     // Transparent mode is only activated when its state probability is high.
110     // Dead zone between activation/deactivation thresholds to avoid switching
111     // back and forth.
112     if (prob_transparent_state_ > 0.95f) {
113       transparency_activated_ = true;
114     } else if (prob_transparent_state_ < 0.5f) {
115       transparency_activated_ = false;
116     }
117   }
118 
119  private:
120   bool transparency_activated_ = false;
121   float prob_transparent_state_ = 0.f;
122 };
123 
124 // Legacy classifier for toggling transparent mode.
125 class LegacyTransparentModeImpl : public TransparentMode {
126  public:
LegacyTransparentModeImpl(const EchoCanceller3Config & config)127   explicit LegacyTransparentModeImpl(const EchoCanceller3Config& config)
128       : linear_and_stable_echo_path_(
129             config.echo_removal_control.linear_and_stable_echo_path),
130         active_blocks_since_sane_filter_(kBlocksSinceConsistentEstimateInit),
131         non_converged_sequence_size_(kBlocksSinceConvergencedFilterInit) {}
132 
Active() const133   bool Active() const override { return transparency_activated_; }
134 
Reset()135   void Reset() override {
136     non_converged_sequence_size_ = kBlocksSinceConvergencedFilterInit;
137     diverged_sequence_size_ = 0;
138     strong_not_saturated_render_blocks_ = 0;
139     if (linear_and_stable_echo_path_) {
140       recent_convergence_during_activity_ = false;
141     }
142   }
143 
Update(int filter_delay_blocks,bool any_filter_consistent,bool any_filter_converged,bool any_coarse_filter_converged,bool all_filters_diverged,bool active_render,bool saturated_capture)144   void Update(int filter_delay_blocks,
145               bool any_filter_consistent,
146               bool any_filter_converged,
147               bool any_coarse_filter_converged,
148               bool all_filters_diverged,
149               bool active_render,
150               bool saturated_capture) override {
151     ++capture_block_counter_;
152     strong_not_saturated_render_blocks_ +=
153         active_render && !saturated_capture ? 1 : 0;
154 
155     if (any_filter_consistent && filter_delay_blocks < 5) {
156       sane_filter_observed_ = true;
157       active_blocks_since_sane_filter_ = 0;
158     } else if (active_render) {
159       ++active_blocks_since_sane_filter_;
160     }
161 
162     bool sane_filter_recently_seen;
163     if (!sane_filter_observed_) {
164       sane_filter_recently_seen =
165           capture_block_counter_ <= 5 * kNumBlocksPerSecond;
166     } else {
167       sane_filter_recently_seen =
168           active_blocks_since_sane_filter_ <= 30 * kNumBlocksPerSecond;
169     }
170 
171     if (any_filter_converged) {
172       recent_convergence_during_activity_ = true;
173       active_non_converged_sequence_size_ = 0;
174       non_converged_sequence_size_ = 0;
175       ++num_converged_blocks_;
176     } else {
177       if (++non_converged_sequence_size_ > 20 * kNumBlocksPerSecond) {
178         num_converged_blocks_ = 0;
179       }
180 
181       if (active_render &&
182           ++active_non_converged_sequence_size_ > 60 * kNumBlocksPerSecond) {
183         recent_convergence_during_activity_ = false;
184       }
185     }
186 
187     if (!all_filters_diverged) {
188       diverged_sequence_size_ = 0;
189     } else if (++diverged_sequence_size_ >= 60) {
190       // TODO(peah): Change these lines to ensure proper triggering of usable
191       // filter.
192       non_converged_sequence_size_ = kBlocksSinceConvergencedFilterInit;
193     }
194 
195     if (active_non_converged_sequence_size_ > 60 * kNumBlocksPerSecond) {
196       finite_erl_recently_detected_ = false;
197     }
198     if (num_converged_blocks_ > 50) {
199       finite_erl_recently_detected_ = true;
200     }
201 
202     if (finite_erl_recently_detected_) {
203       transparency_activated_ = false;
204     } else if (sane_filter_recently_seen &&
205                recent_convergence_during_activity_) {
206       transparency_activated_ = false;
207     } else {
208       const bool filter_should_have_converged =
209           strong_not_saturated_render_blocks_ > 6 * kNumBlocksPerSecond;
210       transparency_activated_ = filter_should_have_converged;
211     }
212   }
213 
214  private:
215   const bool linear_and_stable_echo_path_;
216   size_t capture_block_counter_ = 0;
217   bool transparency_activated_ = false;
218   size_t active_blocks_since_sane_filter_;
219   bool sane_filter_observed_ = false;
220   bool finite_erl_recently_detected_ = false;
221   size_t non_converged_sequence_size_;
222   size_t diverged_sequence_size_ = 0;
223   size_t active_non_converged_sequence_size_ = 0;
224   size_t num_converged_blocks_ = 0;
225   bool recent_convergence_during_activity_ = false;
226   size_t strong_not_saturated_render_blocks_ = 0;
227 };
228 
Create(const EchoCanceller3Config & config)229 std::unique_ptr<TransparentMode> TransparentMode::Create(
230     const EchoCanceller3Config& config) {
231   if (config.ep_strength.bounded_erl || DeactivateTransparentMode()) {
232     RTC_LOG(LS_INFO) << "AEC3 Transparent Mode: Disabled";
233     return nullptr;
234   }
235   if (ActivateTransparentModeHmm()) {
236     RTC_LOG(LS_INFO) << "AEC3 Transparent Mode: HMM";
237     return std::make_unique<TransparentModeImpl>();
238   }
239   RTC_LOG(LS_INFO) << "AEC3 Transparent Mode: Legacy";
240   return std::make_unique<LegacyTransparentModeImpl>(config);
241 }
242 
243 }  // namespace webrtc
244