1 /*
2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/agc2/vad_wrapper.h"
12
13 #include <limits>
14 #include <memory>
15 #include <tuple>
16 #include <utility>
17 #include <vector>
18
19 #include "modules/audio_processing/agc2/agc2_common.h"
20 #include "modules/audio_processing/include/audio_frame_view.h"
21 #include "rtc_base/checks.h"
22 #include "rtc_base/gunit.h"
23 #include "rtc_base/numerics/safe_compare.h"
24 #include "test/gmock.h"
25
26 namespace webrtc {
27 namespace {
28
29 using ::testing::AnyNumber;
30 using ::testing::Return;
31 using ::testing::ReturnRoundRobin;
32 using ::testing::Truly;
33
34 constexpr int kNumFramesPerSecond = 100;
35
36 constexpr int kNoVadPeriodicReset =
37 kFrameDurationMs * (std::numeric_limits<int>::max() / kFrameDurationMs);
38
39 constexpr int kSampleRate8kHz = 8000;
40
41 class MockVad : public VoiceActivityDetectorWrapper::MonoVad {
42 public:
43 MOCK_METHOD(int, SampleRateHz, (), (const, override));
44 MOCK_METHOD(void, Reset, (), (override));
45 MOCK_METHOD(float, Analyze, (rtc::ArrayView<const float> frame), (override));
46 };
47
48 // Checks that the ctor and `Initialize()` read the sample rate of the wrapped
49 // VAD.
TEST(GainController2VoiceActivityDetectorWrapper,CtorAndInitReadSampleRate)50 TEST(GainController2VoiceActivityDetectorWrapper, CtorAndInitReadSampleRate) {
51 auto vad = std::make_unique<MockVad>();
52 EXPECT_CALL(*vad, SampleRateHz)
53 .Times(2)
54 .WillRepeatedly(Return(kSampleRate8kHz));
55 EXPECT_CALL(*vad, Reset).Times(AnyNumber());
56 auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>(
57 kNoVadPeriodicReset, std::move(vad), kSampleRate8kHz);
58 }
59
60 // Creates a `VoiceActivityDetectorWrapper` injecting a mock VAD that
61 // repeatedly returns the next value from `speech_probabilities` and that
62 // restarts from the beginning when after the last element is returned.
CreateMockVadWrapper(int vad_reset_period_ms,int sample_rate_hz,const std::vector<float> & speech_probabilities,int expected_vad_reset_calls)63 std::unique_ptr<VoiceActivityDetectorWrapper> CreateMockVadWrapper(
64 int vad_reset_period_ms,
65 int sample_rate_hz,
66 const std::vector<float>& speech_probabilities,
67 int expected_vad_reset_calls) {
68 auto vad = std::make_unique<MockVad>();
69 EXPECT_CALL(*vad, SampleRateHz)
70 .Times(AnyNumber())
71 .WillRepeatedly(Return(sample_rate_hz));
72 if (expected_vad_reset_calls >= 0) {
73 EXPECT_CALL(*vad, Reset).Times(expected_vad_reset_calls);
74 }
75 EXPECT_CALL(*vad, Analyze)
76 .Times(AnyNumber())
77 .WillRepeatedly(ReturnRoundRobin(speech_probabilities));
78 return std::make_unique<VoiceActivityDetectorWrapper>(
79 vad_reset_period_ms, std::move(vad), kSampleRate8kHz);
80 }
81
82 // 10 ms mono frame.
83 struct FrameWithView {
84 // Ctor. Initializes the frame samples with `value`.
FrameWithViewwebrtc::__anondeb6effc0111::FrameWithView85 explicit FrameWithView(int sample_rate_hz)
86 : samples(rtc::CheckedDivExact(sample_rate_hz, kNumFramesPerSecond),
87 0.0f),
88 channel0(samples.data()),
89 view(&channel0, /*num_channels=*/1, samples.size()) {}
90 std::vector<float> samples;
91 const float* const channel0;
92 const AudioFrameView<const float> view;
93 };
94
95 // Checks that the expected speech probabilities are returned.
TEST(GainController2VoiceActivityDetectorWrapper,CheckSpeechProbabilities)96 TEST(GainController2VoiceActivityDetectorWrapper, CheckSpeechProbabilities) {
97 const std::vector<float> speech_probabilities{0.709f, 0.484f, 0.882f, 0.167f,
98 0.44f, 0.525f, 0.858f, 0.314f,
99 0.653f, 0.965f, 0.413f, 0.0f};
100 auto vad_wrapper = CreateMockVadWrapper(kNoVadPeriodicReset, kSampleRate8kHz,
101 speech_probabilities,
102 /*expected_vad_reset_calls=*/1);
103 FrameWithView frame(kSampleRate8kHz);
104 for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) {
105 SCOPED_TRACE(i);
106 EXPECT_EQ(speech_probabilities[i], vad_wrapper->Analyze(frame.view));
107 }
108 }
109
110 // Checks that the VAD is not periodically reset.
TEST(GainController2VoiceActivityDetectorWrapper,VadNoPeriodicReset)111 TEST(GainController2VoiceActivityDetectorWrapper, VadNoPeriodicReset) {
112 constexpr int kNumFrames = 19;
113 auto vad_wrapper = CreateMockVadWrapper(kNoVadPeriodicReset, kSampleRate8kHz,
114 /*speech_probabilities=*/{1.0f},
115 /*expected_vad_reset_calls=*/1);
116 FrameWithView frame(kSampleRate8kHz);
117 for (int i = 0; i < kNumFrames; ++i) {
118 vad_wrapper->Analyze(frame.view);
119 }
120 }
121
122 class VadPeriodResetParametrization
123 : public ::testing::TestWithParam<std::tuple<int, int>> {
124 protected:
num_frames() const125 int num_frames() const { return std::get<0>(GetParam()); }
vad_reset_period_frames() const126 int vad_reset_period_frames() const { return std::get<1>(GetParam()); }
127 };
128
129 // Checks that the VAD is periodically reset with the expected period.
TEST_P(VadPeriodResetParametrization,VadPeriodicReset)130 TEST_P(VadPeriodResetParametrization, VadPeriodicReset) {
131 auto vad_wrapper = CreateMockVadWrapper(
132 /*vad_reset_period_ms=*/vad_reset_period_frames() * kFrameDurationMs,
133 kSampleRate8kHz,
134 /*speech_probabilities=*/{1.0f},
135 /*expected_vad_reset_calls=*/1 +
136 num_frames() / vad_reset_period_frames());
137 FrameWithView frame(kSampleRate8kHz);
138 for (int i = 0; i < num_frames(); ++i) {
139 vad_wrapper->Analyze(frame.view);
140 }
141 }
142
143 INSTANTIATE_TEST_SUITE_P(GainController2VoiceActivityDetectorWrapper,
144 VadPeriodResetParametrization,
145 ::testing::Combine(::testing::Values(1, 19, 123),
146 ::testing::Values(2, 5, 20, 53)));
147
148 class VadResamplingParametrization
149 : public ::testing::TestWithParam<std::tuple<int, int>> {
150 protected:
input_sample_rate_hz() const151 int input_sample_rate_hz() const { return std::get<0>(GetParam()); }
vad_sample_rate_hz() const152 int vad_sample_rate_hz() const { return std::get<1>(GetParam()); }
153 };
154
155 // Checks that regardless of the input audio sample rate, the wrapped VAD
156 // analyzes frames having the expected size, that is according to its internal
157 // sample rate.
TEST_P(VadResamplingParametrization,CheckResampledFrameSize)158 TEST_P(VadResamplingParametrization, CheckResampledFrameSize) {
159 auto vad = std::make_unique<MockVad>();
160 EXPECT_CALL(*vad, SampleRateHz)
161 .Times(AnyNumber())
162 .WillRepeatedly(Return(vad_sample_rate_hz()));
163 EXPECT_CALL(*vad, Reset).Times(1);
164 EXPECT_CALL(*vad, Analyze(Truly([this](rtc::ArrayView<const float> frame) {
165 return rtc::SafeEq(frame.size(), rtc::CheckedDivExact(vad_sample_rate_hz(),
166 kNumFramesPerSecond));
167 }))).Times(1);
168 auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>(
169 kNoVadPeriodicReset, std::move(vad), input_sample_rate_hz());
170 FrameWithView frame(input_sample_rate_hz());
171 vad_wrapper->Analyze(frame.view);
172 }
173
174 INSTANTIATE_TEST_SUITE_P(
175 GainController2VoiceActivityDetectorWrapper,
176 VadResamplingParametrization,
177 ::testing::Combine(::testing::Values(8000, 16000, 44100, 48000),
178 ::testing::Values(6000, 8000, 12000, 16000, 24000)));
179
180 } // namespace
181 } // namespace webrtc
182