1 /*
2 * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
12
13 #include <array>
14 #include <memory>
15 #include <vector>
16
17 #include "api/array_view.h"
18 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
19 #include "modules/audio_processing/test/performance_timer.h"
20 #include "rtc_base/checks.h"
21 #include "rtc_base/logging.h"
22 #include "test/gtest.h"
23 #include "third_party/rnnoise/src/rnn_vad_weights.h"
24
25 namespace webrtc {
26 namespace rnn_vad {
27 namespace {
28
TestGatedRecurrentLayer(GatedRecurrentLayer & gru,rtc::ArrayView<const float> input_sequence,rtc::ArrayView<const float> expected_output_sequence)29 void TestGatedRecurrentLayer(
30 GatedRecurrentLayer& gru,
31 rtc::ArrayView<const float> input_sequence,
32 rtc::ArrayView<const float> expected_output_sequence) {
33 const int input_sequence_length = rtc::CheckedDivExact(
34 rtc::dchecked_cast<int>(input_sequence.size()), gru.input_size());
35 const int output_sequence_length = rtc::CheckedDivExact(
36 rtc::dchecked_cast<int>(expected_output_sequence.size()), gru.size());
37 ASSERT_EQ(input_sequence_length, output_sequence_length)
38 << "The test data length is invalid.";
39 // Feed the GRU layer and check the output at every step.
40 gru.Reset();
41 for (int i = 0; i < input_sequence_length; ++i) {
42 SCOPED_TRACE(i);
43 gru.ComputeOutput(
44 input_sequence.subview(i * gru.input_size(), gru.input_size()));
45 const auto expected_output =
46 expected_output_sequence.subview(i * gru.size(), gru.size());
47 ExpectNearAbsolute(expected_output, gru, 3e-6f);
48 }
49 }
50
// Gated recurrent unit (GRU) layer test data: a layer with kGruInputSize (5)
// inputs and kGruOutputSize (4) outputs described by int8_t parameters, a
// 4-step input sequence and the expected layer output after each step.
constexpr int kGruInputSize = 5;
constexpr int kGruOutputSize = 4;
// 3 * kGruOutputSize (12) bias values. NOTE(review): presumably laid out as
// update, reset and output biases, mirroring the weight layout below; confirm
// against the GatedRecurrentLayer implementation.
constexpr std::array<int8_t, 12> kGruBias = {96, -99, -81, -114, 49, 119,
                                             -118, 68, -76, 91, 121, 125};
// kGruInputSize * 3 * kGruOutputSize (60) input weights, grouped by input and,
// within each input, by gate (update, reset, output), kGruOutputSize each.
constexpr std::array<int8_t, 60> kGruWeights = {
    // Input 0.
    124, 9, 1, 116,       // Update.
    -66, -21, -118, -110,  // Reset.
    104, 75, -23, -51,     // Output.
    // Input 1.
    -72, -111, 47, 93,   // Update.
    77, -98, 41, -8,     // Reset.
    40, -23, -43, -107,  // Output.
    // Input 2.
    9, -73, 30, -32,      // Update.
    -2, 64, -26, 91,      // Reset.
    -48, -24, -28, -104,  // Output.
    // Input 3.
    74, -46, 116, 15,    // Update.
    32, 52, -126, -38,   // Reset.
    -121, 12, -16, 110,  // Output.
    // Input 4.
    -95, 66, -103, -35,  // Update.
    -38, 3, -126, -61,   // Reset.
    28, 98, -117, -43    // Output.
};
// kGruOutputSize * 3 * kGruOutputSize (48) recurrent weights, grouped by
// output and, within each output, by gate (update, reset, output),
// kGruOutputSize each.
constexpr std::array<int8_t, 48> kGruRecurrentWeights = {
    // Output 0.
    -3, 87, 50, 51,      // Update.
    -22, 27, -39, 62,    // Reset.
    31, -83, -52, -48,   // Output.
    // Output 1.
    -6, 83, -19, 104,    // Update.
    105, 48, 23, 68,     // Reset.
    23, 40, 7, -120,     // Output.
    // Output 2.
    64, -62, 117, 85,     // Update.
    51, -43, 54, -105,    // Reset.
    120, 56, -128, -107,  // Output.
    // Output 3.
    39, 50, -17, -47,    // Update.
    -117, 14, 108, 12,   // Reset.
    -7, -72, 103, -87,   // Output.
};
// Input sequence: 4 steps of kGruInputSize (5) values each.
constexpr std::array<float, 20> kGruInputSequence = {
    0.89395463f, 0.93224651f, 0.55788344f, 0.32341808f, 0.93355054f,
    0.13475326f, 0.97370994f, 0.14253306f, 0.93710381f, 0.76093364f,
    0.65780413f, 0.41657975f, 0.49403164f, 0.46843281f, 0.75138855f,
    0.24517593f, 0.47657707f, 0.57064998f, 0.435184f,   0.19319285f};
// Expected layer output after each of the 4 input steps: kGruOutputSize (4)
// values per step.
constexpr std::array<float, 16> kGruExpectedOutputSequence = {
    0.0239123f,  0.5773077f,  0.f,         0.f,
    0.01282811f, 0.64330572f, 0.f,         0.04863098f,
    0.00781069f, 0.75267816f, 0.f,         0.02579715f,
    0.00471378f, 0.59162533f, 0.11087593f, 0.01334511f};
106
107 class RnnGruParametrization
108 : public ::testing::TestWithParam<AvailableCpuFeatures> {};
109
110 // Checks that the output of a GRU layer is within tolerance given test input
111 // data.
TEST_P(RnnGruParametrization,CheckGatedRecurrentLayer)112 TEST_P(RnnGruParametrization, CheckGatedRecurrentLayer) {
113 GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
114 kGruRecurrentWeights,
115 /*cpu_features=*/GetParam(),
116 /*layer_name=*/"GRU");
117 TestGatedRecurrentLayer(gru, kGruInputSequence, kGruExpectedOutputSequence);
118 }
119
// Benchmarks `GatedRecurrentLayer::ComputeOutput()` with the RNNoise hidden GRU
// parameters over the whole pre-recorded GRU input sequence. The run is
// repeated kNumTests times, and the mean and standard deviation of the run
// duration are logged in milliseconds. Disabled by default; run it with
// --gtest_also_run_disabled_tests.
TEST_P(RnnGruParametrization, DISABLED_BenchmarkGatedRecurrentLayer) {
  // Prefetch test data so that file I/O is not part of the timed section.
  std::unique_ptr<FileReader> reader = CreateGruInputReader();
  std::vector<float> gru_input_sequence(reader->size());
  // A short read would otherwise silently benchmark partially filled data.
  ASSERT_TRUE(reader->ReadChunk(gru_input_sequence))
      << "Cannot read the GRU input test data.";

  using ::rnnoise::kHiddenGruBias;
  using ::rnnoise::kHiddenGruRecurrentWeights;
  using ::rnnoise::kHiddenGruWeights;
  using ::rnnoise::kHiddenLayerOutputSize;
  using ::rnnoise::kInputLayerOutputSize;

  GatedRecurrentLayer gru(kInputLayerOutputSize, kHiddenLayerOutputSize,
                          kHiddenGruBias, kHiddenGruWeights,
                          kHiddenGruRecurrentWeights,
                          /*cpu_features=*/GetParam(),
                          /*layer_name=*/"GRU");

  rtc::ArrayView<const float> input_sequence(gru_input_sequence);
  // Use the layer's own input size for both the divisibility check and the
  // slicing below. Convert the size explicitly (checked in debug builds)
  // instead of relying on an implicit size_t -> int narrowing.
  const int input_size = gru.input_size();
  const int num_values = rtc::dchecked_cast<int>(input_sequence.size());
  ASSERT_EQ(num_values % input_size, 0) << "The test data length is invalid.";
  const int input_sequence_length = num_values / input_size;

  constexpr int kNumTests = 100;
  ::webrtc::test::PerformanceTimer perf_timer(kNumTests);
  for (int k = 0; k < kNumTests; ++k) {
    perf_timer.StartTimer();
    for (int i = 0; i < input_sequence_length; ++i) {
      gru.ComputeOutput(input_sequence.subview(i * input_size, input_size));
    }
    perf_timer.StopTimer();
  }
  RTC_LOG(LS_INFO) << (perf_timer.GetDurationAverage() / 1000) << " +/- "
                   << (perf_timer.GetDurationStandardDeviation() / 1000)
                   << " ms";
}
158
159 // Finds the relevant CPU features combinations to test.
GetCpuFeaturesToTest()160 std::vector<AvailableCpuFeatures> GetCpuFeaturesToTest() {
161 std::vector<AvailableCpuFeatures> v;
162 v.push_back(NoAvailableCpuFeatures());
163 AvailableCpuFeatures available = GetAvailableCpuFeatures();
164 if (available.sse2) {
165 v.push_back({/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
166 }
167 if (available.avx2) {
168 v.push_back({/*sse2=*/false, /*avx2=*/true, /*neon=*/false});
169 }
170 if (available.neon) {
171 v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/true});
172 }
173 return v;
174 }
175
176 INSTANTIATE_TEST_SUITE_P(
177 RnnVadTest,
178 RnnGruParametrization,
179 ::testing::ValuesIn(GetCpuFeaturesToTest()),
__anon1fc98c130202(const ::testing::TestParamInfo<AvailableCpuFeatures>& info) 180 [](const ::testing::TestParamInfo<AvailableCpuFeatures>& info) {
181 return info.param.ToString();
182 });
183
184 } // namespace
185 } // namespace rnn_vad
186 } // namespace webrtc
187