xref: /aosp_15_r20/external/webrtc/modules/audio_processing/agc2/vad_wrapper_unittest.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/agc2/vad_wrapper.h"
12 
13 #include <limits>
14 #include <memory>
15 #include <tuple>
16 #include <utility>
17 #include <vector>
18 
19 #include "modules/audio_processing/agc2/agc2_common.h"
20 #include "modules/audio_processing/include/audio_frame_view.h"
21 #include "rtc_base/checks.h"
22 #include "rtc_base/gunit.h"
23 #include "rtc_base/numerics/safe_compare.h"
24 #include "test/gmock.h"
25 
26 namespace webrtc {
27 namespace {
28 
29 using ::testing::AnyNumber;
30 using ::testing::Return;
31 using ::testing::ReturnRoundRobin;
32 using ::testing::Truly;
33 
34 constexpr int kNumFramesPerSecond = 100;
35 
36 constexpr int kNoVadPeriodicReset =
37     kFrameDurationMs * (std::numeric_limits<int>::max() / kFrameDurationMs);
38 
39 constexpr int kSampleRate8kHz = 8000;
40 
41 class MockVad : public VoiceActivityDetectorWrapper::MonoVad {
42  public:
43   MOCK_METHOD(int, SampleRateHz, (), (const, override));
44   MOCK_METHOD(void, Reset, (), (override));
45   MOCK_METHOD(float, Analyze, (rtc::ArrayView<const float> frame), (override));
46 };
47 
48 // Checks that the ctor and `Initialize()` read the sample rate of the wrapped
49 // VAD.
TEST(GainController2VoiceActivityDetectorWrapper,CtorAndInitReadSampleRate)50 TEST(GainController2VoiceActivityDetectorWrapper, CtorAndInitReadSampleRate) {
51   auto vad = std::make_unique<MockVad>();
52   EXPECT_CALL(*vad, SampleRateHz)
53       .Times(2)
54       .WillRepeatedly(Return(kSampleRate8kHz));
55   EXPECT_CALL(*vad, Reset).Times(AnyNumber());
56   auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>(
57       kNoVadPeriodicReset, std::move(vad), kSampleRate8kHz);
58 }
59 
60 // Creates a `VoiceActivityDetectorWrapper` injecting a mock VAD that
61 // repeatedly returns the next value from `speech_probabilities` and that
62 // restarts from the beginning when after the last element is returned.
CreateMockVadWrapper(int vad_reset_period_ms,int sample_rate_hz,const std::vector<float> & speech_probabilities,int expected_vad_reset_calls)63 std::unique_ptr<VoiceActivityDetectorWrapper> CreateMockVadWrapper(
64     int vad_reset_period_ms,
65     int sample_rate_hz,
66     const std::vector<float>& speech_probabilities,
67     int expected_vad_reset_calls) {
68   auto vad = std::make_unique<MockVad>();
69   EXPECT_CALL(*vad, SampleRateHz)
70       .Times(AnyNumber())
71       .WillRepeatedly(Return(sample_rate_hz));
72   if (expected_vad_reset_calls >= 0) {
73     EXPECT_CALL(*vad, Reset).Times(expected_vad_reset_calls);
74   }
75   EXPECT_CALL(*vad, Analyze)
76       .Times(AnyNumber())
77       .WillRepeatedly(ReturnRoundRobin(speech_probabilities));
78   return std::make_unique<VoiceActivityDetectorWrapper>(
79       vad_reset_period_ms, std::move(vad), kSampleRate8kHz);
80 }
81 
82 // 10 ms mono frame.
83 struct FrameWithView {
84   // Ctor. Initializes the frame samples with `value`.
FrameWithViewwebrtc::__anondeb6effc0111::FrameWithView85   explicit FrameWithView(int sample_rate_hz)
86       : samples(rtc::CheckedDivExact(sample_rate_hz, kNumFramesPerSecond),
87                 0.0f),
88         channel0(samples.data()),
89         view(&channel0, /*num_channels=*/1, samples.size()) {}
90   std::vector<float> samples;
91   const float* const channel0;
92   const AudioFrameView<const float> view;
93 };
94 
95 // Checks that the expected speech probabilities are returned.
TEST(GainController2VoiceActivityDetectorWrapper,CheckSpeechProbabilities)96 TEST(GainController2VoiceActivityDetectorWrapper, CheckSpeechProbabilities) {
97   const std::vector<float> speech_probabilities{0.709f, 0.484f, 0.882f, 0.167f,
98                                                 0.44f,  0.525f, 0.858f, 0.314f,
99                                                 0.653f, 0.965f, 0.413f, 0.0f};
100   auto vad_wrapper = CreateMockVadWrapper(kNoVadPeriodicReset, kSampleRate8kHz,
101                                           speech_probabilities,
102                                           /*expected_vad_reset_calls=*/1);
103   FrameWithView frame(kSampleRate8kHz);
104   for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) {
105     SCOPED_TRACE(i);
106     EXPECT_EQ(speech_probabilities[i], vad_wrapper->Analyze(frame.view));
107   }
108 }
109 
110 // Checks that the VAD is not periodically reset.
TEST(GainController2VoiceActivityDetectorWrapper,VadNoPeriodicReset)111 TEST(GainController2VoiceActivityDetectorWrapper, VadNoPeriodicReset) {
112   constexpr int kNumFrames = 19;
113   auto vad_wrapper = CreateMockVadWrapper(kNoVadPeriodicReset, kSampleRate8kHz,
114                                           /*speech_probabilities=*/{1.0f},
115                                           /*expected_vad_reset_calls=*/1);
116   FrameWithView frame(kSampleRate8kHz);
117   for (int i = 0; i < kNumFrames; ++i) {
118     vad_wrapper->Analyze(frame.view);
119   }
120 }
121 
122 class VadPeriodResetParametrization
123     : public ::testing::TestWithParam<std::tuple<int, int>> {
124  protected:
num_frames() const125   int num_frames() const { return std::get<0>(GetParam()); }
vad_reset_period_frames() const126   int vad_reset_period_frames() const { return std::get<1>(GetParam()); }
127 };
128 
129 // Checks that the VAD is periodically reset with the expected period.
TEST_P(VadPeriodResetParametrization,VadPeriodicReset)130 TEST_P(VadPeriodResetParametrization, VadPeriodicReset) {
131   auto vad_wrapper = CreateMockVadWrapper(
132       /*vad_reset_period_ms=*/vad_reset_period_frames() * kFrameDurationMs,
133       kSampleRate8kHz,
134       /*speech_probabilities=*/{1.0f},
135       /*expected_vad_reset_calls=*/1 +
136           num_frames() / vad_reset_period_frames());
137   FrameWithView frame(kSampleRate8kHz);
138   for (int i = 0; i < num_frames(); ++i) {
139     vad_wrapper->Analyze(frame.view);
140   }
141 }
142 
143 INSTANTIATE_TEST_SUITE_P(GainController2VoiceActivityDetectorWrapper,
144                          VadPeriodResetParametrization,
145                          ::testing::Combine(::testing::Values(1, 19, 123),
146                                             ::testing::Values(2, 5, 20, 53)));
147 
148 class VadResamplingParametrization
149     : public ::testing::TestWithParam<std::tuple<int, int>> {
150  protected:
input_sample_rate_hz() const151   int input_sample_rate_hz() const { return std::get<0>(GetParam()); }
vad_sample_rate_hz() const152   int vad_sample_rate_hz() const { return std::get<1>(GetParam()); }
153 };
154 
155 // Checks that regardless of the input audio sample rate, the wrapped VAD
156 // analyzes frames having the expected size, that is according to its internal
157 // sample rate.
TEST_P(VadResamplingParametrization,CheckResampledFrameSize)158 TEST_P(VadResamplingParametrization, CheckResampledFrameSize) {
159   auto vad = std::make_unique<MockVad>();
160   EXPECT_CALL(*vad, SampleRateHz)
161       .Times(AnyNumber())
162       .WillRepeatedly(Return(vad_sample_rate_hz()));
163   EXPECT_CALL(*vad, Reset).Times(1);
164   EXPECT_CALL(*vad, Analyze(Truly([this](rtc::ArrayView<const float> frame) {
165     return rtc::SafeEq(frame.size(), rtc::CheckedDivExact(vad_sample_rate_hz(),
166                                                           kNumFramesPerSecond));
167   }))).Times(1);
168   auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>(
169       kNoVadPeriodicReset, std::move(vad), input_sample_rate_hz());
170   FrameWithView frame(input_sample_rate_hz());
171   vad_wrapper->Analyze(frame.view);
172 }
173 
174 INSTANTIATE_TEST_SUITE_P(
175     GainController2VoiceActivityDetectorWrapper,
176     VadResamplingParametrization,
177     ::testing::Combine(::testing::Values(8000, 16000, 44100, 48000),
178                        ::testing::Values(6000, 8000, 12000, 16000, 24000)));
179 
180 }  // namespace
181 }  // namespace webrtc
182