1 /*
2 * Copyright (c) 2015 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include <memory>
12
13 #include "absl/memory/memory.h"
14 #include "api/audio/audio_frame.h"
15 #include "api/audio_codecs/audio_decoder.h"
16 #include "api/audio_codecs/builtin_audio_decoder_factory.h"
17 #include "api/neteq/neteq.h"
18 #include "modules/audio_coding/neteq/default_neteq_factory.h"
19 #include "modules/audio_coding/neteq/tools/rtp_generator.h"
20 #include "system_wrappers/include/clock.h"
21 #include "test/audio_decoder_proxy_factory.h"
22 #include "test/gmock.h"
23
24 namespace webrtc {
25 namespace test {
26
27 namespace {
28
CreateNetEq(const NetEq::Config & config,Clock * clock,const rtc::scoped_refptr<AudioDecoderFactory> & decoder_factory)29 std::unique_ptr<NetEq> CreateNetEq(
30 const NetEq::Config& config,
31 Clock* clock,
32 const rtc::scoped_refptr<AudioDecoderFactory>& decoder_factory) {
33 return DefaultNetEqFactory().CreateNetEq(config, decoder_factory, clock);
34 }
35
36 } // namespace
37
38 using ::testing::_;
39 using ::testing::Return;
40 using ::testing::SetArgPointee;
41
42 class MockAudioDecoder final : public AudioDecoder {
43 public:
44 static const int kPacketDuration = 960; // 48 kHz * 20 ms
45
MockAudioDecoder(int sample_rate_hz,size_t num_channels)46 MockAudioDecoder(int sample_rate_hz, size_t num_channels)
47 : sample_rate_hz_(sample_rate_hz),
48 num_channels_(num_channels),
49 fec_enabled_(false) {}
~MockAudioDecoder()50 ~MockAudioDecoder() override { Die(); }
51 MOCK_METHOD(void, Die, ());
52
53 MOCK_METHOD(void, Reset, (), (override));
54
55 class MockFrame : public AudioDecoder::EncodedAudioFrame {
56 public:
MockFrame(size_t num_channels)57 MockFrame(size_t num_channels) : num_channels_(num_channels) {}
58
Duration() const59 size_t Duration() const override { return kPacketDuration; }
60
Decode(rtc::ArrayView<int16_t> decoded) const61 absl::optional<DecodeResult> Decode(
62 rtc::ArrayView<int16_t> decoded) const override {
63 const size_t output_size =
64 sizeof(int16_t) * kPacketDuration * num_channels_;
65 if (decoded.size() >= output_size) {
66 memset(decoded.data(), 0,
67 sizeof(int16_t) * kPacketDuration * num_channels_);
68 return DecodeResult{kPacketDuration * num_channels_, kSpeech};
69 } else {
70 ADD_FAILURE() << "Expected decoded.size() to be >= output_size ("
71 << decoded.size() << " vs. " << output_size << ")";
72 return absl::nullopt;
73 }
74 }
75
76 private:
77 const size_t num_channels_;
78 };
79
ParsePayload(rtc::Buffer && payload,uint32_t timestamp)80 std::vector<ParseResult> ParsePayload(rtc::Buffer&& payload,
81 uint32_t timestamp) override {
82 std::vector<ParseResult> results;
83 if (fec_enabled_) {
84 std::unique_ptr<MockFrame> fec_frame(new MockFrame(num_channels_));
85 results.emplace_back(timestamp - kPacketDuration, 1,
86 std::move(fec_frame));
87 }
88
89 std::unique_ptr<MockFrame> frame(new MockFrame(num_channels_));
90 results.emplace_back(timestamp, 0, std::move(frame));
91 return results;
92 }
93
PacketDuration(const uint8_t * encoded,size_t encoded_len) const94 int PacketDuration(const uint8_t* encoded,
95 size_t encoded_len) const override {
96 ADD_FAILURE() << "Since going through ParsePayload, PacketDuration should "
97 "never get called.";
98 return kPacketDuration;
99 }
100
PacketHasFec(const uint8_t * encoded,size_t encoded_len) const101 bool PacketHasFec(const uint8_t* encoded, size_t encoded_len) const override {
102 ADD_FAILURE() << "Since going through ParsePayload, PacketHasFec should "
103 "never get called.";
104 return fec_enabled_;
105 }
106
SampleRateHz() const107 int SampleRateHz() const override { return sample_rate_hz_; }
108
Channels() const109 size_t Channels() const override { return num_channels_; }
110
set_fec_enabled(bool enable_fec)111 void set_fec_enabled(bool enable_fec) { fec_enabled_ = enable_fec; }
112
fec_enabled() const113 bool fec_enabled() const { return fec_enabled_; }
114
115 protected:
DecodeInternal(const uint8_t * encoded,size_t encoded_len,int sample_rate_hz,int16_t * decoded,SpeechType * speech_type)116 int DecodeInternal(const uint8_t* encoded,
117 size_t encoded_len,
118 int sample_rate_hz,
119 int16_t* decoded,
120 SpeechType* speech_type) override {
121 ADD_FAILURE() << "Since going through ParsePayload, DecodeInternal should "
122 "never get called.";
123 return -1;
124 }
125
126 private:
127 const int sample_rate_hz_;
128 const size_t num_channels_;
129 bool fec_enabled_;
130 };
131
132 class NetEqNetworkStatsTest {
133 public:
134 static const int kPayloadSizeByte = 30;
135 static const int kFrameSizeMs = 20;
136 static const uint8_t kPayloadType = 95;
137 static const int kOutputLengthMs = 10;
138
139 enum logic {
140 kIgnore,
141 kEqual,
142 kSmallerThan,
143 kLargerThan,
144 };
145
146 struct NetEqNetworkStatsCheck {
147 logic current_buffer_size_ms;
148 logic preferred_buffer_size_ms;
149 logic jitter_peaks_found;
150 logic packet_loss_rate;
151 logic expand_rate;
152 logic speech_expand_rate;
153 logic preemptive_rate;
154 logic accelerate_rate;
155 logic secondary_decoded_rate;
156 logic secondary_discarded_rate;
157 logic added_zero_samples;
158 NetEqNetworkStatistics stats_ref;
159 };
160
NetEqNetworkStatsTest(const SdpAudioFormat & format,MockAudioDecoder * decoder)161 NetEqNetworkStatsTest(const SdpAudioFormat& format, MockAudioDecoder* decoder)
162 : decoder_(decoder),
163 decoder_factory_(
164 rtc::make_ref_counted<AudioDecoderProxyFactory>(decoder)),
165 samples_per_ms_(format.clockrate_hz / 1000),
166 frame_size_samples_(kFrameSizeMs * samples_per_ms_),
167 rtp_generator_(new RtpGenerator(samples_per_ms_)),
168 last_lost_time_(0),
169 packet_loss_interval_(0xffffffff) {
170 NetEq::Config config;
171 config.sample_rate_hz = format.clockrate_hz;
172 neteq_ = CreateNetEq(config, Clock::GetRealTimeClock(), decoder_factory_);
173 neteq_->RegisterPayloadType(kPayloadType, format);
174 }
175
Lost(uint32_t send_time)176 bool Lost(uint32_t send_time) {
177 if (send_time - last_lost_time_ >= packet_loss_interval_) {
178 last_lost_time_ = send_time;
179 return true;
180 }
181 return false;
182 }
183
SetPacketLossRate(double loss_rate)184 void SetPacketLossRate(double loss_rate) {
185 packet_loss_interval_ =
186 (loss_rate >= 1e-3 ? static_cast<double>(kFrameSizeMs) / loss_rate
187 : 0xffffffff);
188 }
189
190 // `stats_ref`
191 // expects.x = -1, do not care
192 // expects.x = 0, 'x' in current stats should equal 'x' in `stats_ref`
193 // expects.x = 1, 'x' in current stats should < 'x' in `stats_ref`
194 // expects.x = 2, 'x' in current stats should > 'x' in `stats_ref`
CheckNetworkStatistics(NetEqNetworkStatsCheck expects)195 void CheckNetworkStatistics(NetEqNetworkStatsCheck expects) {
196 NetEqNetworkStatistics stats;
197 neteq_->NetworkStatistics(&stats);
198
199 #define CHECK_NETEQ_NETWORK_STATS(x) \
200 switch (expects.x) { \
201 case kEqual: \
202 EXPECT_EQ(stats.x, expects.stats_ref.x); \
203 break; \
204 case kSmallerThan: \
205 EXPECT_LT(stats.x, expects.stats_ref.x); \
206 break; \
207 case kLargerThan: \
208 EXPECT_GT(stats.x, expects.stats_ref.x); \
209 break; \
210 default: \
211 break; \
212 }
213
214 CHECK_NETEQ_NETWORK_STATS(current_buffer_size_ms);
215 CHECK_NETEQ_NETWORK_STATS(preferred_buffer_size_ms);
216 CHECK_NETEQ_NETWORK_STATS(jitter_peaks_found);
217 CHECK_NETEQ_NETWORK_STATS(expand_rate);
218 CHECK_NETEQ_NETWORK_STATS(speech_expand_rate);
219 CHECK_NETEQ_NETWORK_STATS(preemptive_rate);
220 CHECK_NETEQ_NETWORK_STATS(accelerate_rate);
221 CHECK_NETEQ_NETWORK_STATS(secondary_decoded_rate);
222 CHECK_NETEQ_NETWORK_STATS(secondary_discarded_rate);
223
224 #undef CHECK_NETEQ_NETWORK_STATS
225 }
226
RunTest(int num_loops,NetEqNetworkStatsCheck expects)227 void RunTest(int num_loops, NetEqNetworkStatsCheck expects) {
228 uint32_t time_now;
229 uint32_t next_send_time;
230
231 // Initiate `last_lost_time_`.
232 time_now = next_send_time = last_lost_time_ = rtp_generator_->GetRtpHeader(
233 kPayloadType, frame_size_samples_, &rtp_header_);
234 for (int k = 0; k < num_loops; ++k) {
235 // Delay by one frame such that the FEC can come in.
236 while (time_now + kFrameSizeMs >= next_send_time) {
237 next_send_time = rtp_generator_->GetRtpHeader(
238 kPayloadType, frame_size_samples_, &rtp_header_);
239 if (!Lost(next_send_time)) {
240 static const uint8_t payload[kPayloadSizeByte] = {0};
241 ASSERT_EQ(NetEq::kOK, neteq_->InsertPacket(rtp_header_, payload));
242 }
243 }
244 bool muted = true;
245 EXPECT_EQ(NetEq::kOK, neteq_->GetAudio(&output_frame_, &muted));
246 ASSERT_FALSE(muted);
247 EXPECT_EQ(decoder_->Channels(), output_frame_.num_channels_);
248 EXPECT_EQ(static_cast<size_t>(kOutputLengthMs * samples_per_ms_),
249 output_frame_.samples_per_channel_);
250 EXPECT_EQ(48000, neteq_->last_output_sample_rate_hz());
251
252 time_now += kOutputLengthMs;
253 }
254 CheckNetworkStatistics(expects);
255 neteq_->FlushBuffers();
256 }
257
DecodeFecTest()258 void DecodeFecTest() {
259 decoder_->set_fec_enabled(false);
260 NetEqNetworkStatsCheck expects = {kIgnore, // current_buffer_size_ms
261 kIgnore, // preferred_buffer_size_ms
262 kIgnore, // jitter_peaks_found
263 kEqual, // packet_loss_rate
264 kEqual, // expand_rate
265 kEqual, // voice_expand_rate
266 kIgnore, // preemptive_rate
267 kEqual, // accelerate_rate
268 kEqual, // decoded_fec_rate
269 kEqual, // discarded_fec_rate
270 kEqual, // added_zero_samples
271 {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}};
272 RunTest(50, expects);
273
274 // Next we introduce packet losses.
275 SetPacketLossRate(0.1);
276 expects.stats_ref.expand_rate = expects.stats_ref.speech_expand_rate = 898;
277 RunTest(50, expects);
278
279 // Next we enable FEC.
280 decoder_->set_fec_enabled(true);
281 // If FEC fills in the lost packets, no packet loss will be counted.
282 expects.stats_ref.expand_rate = expects.stats_ref.speech_expand_rate = 0;
283 expects.stats_ref.secondary_decoded_rate = 2006;
284 expects.stats_ref.secondary_discarded_rate = 14336;
285 RunTest(50, expects);
286 }
287
NoiseExpansionTest()288 void NoiseExpansionTest() {
289 NetEqNetworkStatsCheck expects = {kIgnore, // current_buffer_size_ms
290 kIgnore, // preferred_buffer_size_ms
291 kIgnore, // jitter_peaks_found
292 kEqual, // packet_loss_rate
293 kEqual, // expand_rate
294 kEqual, // speech_expand_rate
295 kIgnore, // preemptive_rate
296 kEqual, // accelerate_rate
297 kEqual, // decoded_fec_rate
298 kEqual, // discard_fec_rate
299 kEqual, // added_zero_samples
300 {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}};
301 RunTest(50, expects);
302
303 SetPacketLossRate(1);
304 expects.stats_ref.expand_rate = 16384;
305 expects.stats_ref.speech_expand_rate = 5324;
306 RunTest(10, expects); // Lost 10 * 20ms in a row.
307 }
308
309 private:
310 MockAudioDecoder* decoder_;
311 rtc::scoped_refptr<AudioDecoderProxyFactory> decoder_factory_;
312 std::unique_ptr<NetEq> neteq_;
313
314 const int samples_per_ms_;
315 const size_t frame_size_samples_;
316 std::unique_ptr<RtpGenerator> rtp_generator_;
317 RTPHeader rtp_header_;
318 uint32_t last_lost_time_;
319 uint32_t packet_loss_interval_;
320 AudioFrame output_frame_;
321 };
322
TEST(NetEqNetworkStatsTest, DecodeFec) {
  // Mono mock decoder behind a stereo Opus SDP format.
  MockAudioDecoder mono_decoder(/*sample_rate_hz=*/48000, /*num_channels=*/1);
  const SdpAudioFormat opus_format("opus", 48000, 2);
  NetEqNetworkStatsTest stats_test(opus_format, &mono_decoder);
  stats_test.DecodeFecTest();
  // `mono_decoder` is declared first, so it is destroyed last when this scope
  // exits; its destructor must call Die() exactly once.
  EXPECT_CALL(mono_decoder, Die()).Times(1);
}
329
TEST(NetEqNetworkStatsTest, StereoDecodeFec) {
  // Stereo mock decoder behind a stereo Opus SDP format.
  MockAudioDecoder stereo_decoder(/*sample_rate_hz=*/48000,
                                  /*num_channels=*/2);
  const SdpAudioFormat opus_format("opus", 48000, 2);
  NetEqNetworkStatsTest stats_test(opus_format, &stereo_decoder);
  stats_test.DecodeFecTest();
  // `stereo_decoder` is declared first, so it is destroyed last when this
  // scope exits; its destructor must call Die() exactly once.
  EXPECT_CALL(stereo_decoder, Die()).Times(1);
}
336
TEST(NetEqNetworkStatsTest, NoiseExpansionTest) {
  // Mono mock decoder behind a stereo Opus SDP format.
  MockAudioDecoder mono_decoder(/*sample_rate_hz=*/48000, /*num_channels=*/1);
  const SdpAudioFormat opus_format("opus", 48000, 2);
  NetEqNetworkStatsTest stats_test(opus_format, &mono_decoder);
  stats_test.NoiseExpansionTest();
  // `mono_decoder` is declared first, so it is destroyed last when this scope
  // exits; its destructor must call Die() exactly once.
  EXPECT_CALL(mono_decoder, Die()).Times(1);
}
343
344 } // namespace test
345 } // namespace webrtc
346