1 /*
2 * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
12
13 #include "rtc_base/checks.h"
14 #include "rtc_base/numerics/safe_conversions.h"
15 #include "third_party/rnnoise/src/rnn_activations.h"
16 #include "third_party/rnnoise/src/rnn_vad_weights.h"
17
18 namespace webrtc {
19 namespace rnn_vad {
20 namespace {
21
// Number of gates in a GRU layer: update, reset and output.
constexpr int kNumGruGates{3};
23
PreprocessGruTensor(rtc::ArrayView<const int8_t> tensor_src,int output_size)24 std::vector<float> PreprocessGruTensor(rtc::ArrayView<const int8_t> tensor_src,
25 int output_size) {
26 // Transpose, cast and scale.
27 // `n` is the size of the first dimension of the 3-dim tensor `weights`.
28 const int n = rtc::CheckedDivExact(rtc::dchecked_cast<int>(tensor_src.size()),
29 output_size * kNumGruGates);
30 const int stride_src = kNumGruGates * output_size;
31 const int stride_dst = n * output_size;
32 std::vector<float> tensor_dst(tensor_src.size());
33 for (int g = 0; g < kNumGruGates; ++g) {
34 for (int o = 0; o < output_size; ++o) {
35 for (int i = 0; i < n; ++i) {
36 tensor_dst[g * stride_dst + o * n + i] =
37 ::rnnoise::kWeightsScale *
38 static_cast<float>(
39 tensor_src[i * stride_src + g * output_size + o]);
40 }
41 }
42 }
43 return tensor_dst;
44 }
45
46 // Computes the output for the update or the reset gate.
47 // Operation: `g = sigmoid(W^T∙i + R^T∙s + b)` where
48 // - `g`: output gate vector
49 // - `W`: weights matrix
50 // - `i`: input vector
51 // - `R`: recurrent weights matrix
52 // - `s`: state gate vector
53 // - `b`: bias vector
ComputeUpdateResetGate(int input_size,int output_size,const VectorMath & vector_math,rtc::ArrayView<const float> input,rtc::ArrayView<const float> state,rtc::ArrayView<const float> bias,rtc::ArrayView<const float> weights,rtc::ArrayView<const float> recurrent_weights,rtc::ArrayView<float> gate)54 void ComputeUpdateResetGate(int input_size,
55 int output_size,
56 const VectorMath& vector_math,
57 rtc::ArrayView<const float> input,
58 rtc::ArrayView<const float> state,
59 rtc::ArrayView<const float> bias,
60 rtc::ArrayView<const float> weights,
61 rtc::ArrayView<const float> recurrent_weights,
62 rtc::ArrayView<float> gate) {
63 RTC_DCHECK_EQ(input.size(), input_size);
64 RTC_DCHECK_EQ(state.size(), output_size);
65 RTC_DCHECK_EQ(bias.size(), output_size);
66 RTC_DCHECK_EQ(weights.size(), input_size * output_size);
67 RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
68 RTC_DCHECK_GE(gate.size(), output_size); // `gate` is over-allocated.
69 for (int o = 0; o < output_size; ++o) {
70 float x = bias[o];
71 x += vector_math.DotProduct(input,
72 weights.subview(o * input_size, input_size));
73 x += vector_math.DotProduct(
74 state, recurrent_weights.subview(o * output_size, output_size));
75 gate[o] = ::rnnoise::SigmoidApproximated(x);
76 }
77 }
78
79 // Computes the output for the state gate.
80 // Operation: `s' = u .* s + (1 - u) .* ReLU(W^T∙i + R^T∙(s .* r) + b)` where
81 // - `s'`: output state gate vector
82 // - `s`: previous state gate vector
83 // - `u`: update gate vector
84 // - `W`: weights matrix
85 // - `i`: input vector
86 // - `R`: recurrent weights matrix
87 // - `r`: reset gate vector
88 // - `b`: bias vector
89 // - `.*` element-wise product
ComputeStateGate(int input_size,int output_size,const VectorMath & vector_math,rtc::ArrayView<const float> input,rtc::ArrayView<const float> update,rtc::ArrayView<const float> reset,rtc::ArrayView<const float> bias,rtc::ArrayView<const float> weights,rtc::ArrayView<const float> recurrent_weights,rtc::ArrayView<float> state)90 void ComputeStateGate(int input_size,
91 int output_size,
92 const VectorMath& vector_math,
93 rtc::ArrayView<const float> input,
94 rtc::ArrayView<const float> update,
95 rtc::ArrayView<const float> reset,
96 rtc::ArrayView<const float> bias,
97 rtc::ArrayView<const float> weights,
98 rtc::ArrayView<const float> recurrent_weights,
99 rtc::ArrayView<float> state) {
100 RTC_DCHECK_EQ(input.size(), input_size);
101 RTC_DCHECK_GE(update.size(), output_size); // `update` is over-allocated.
102 RTC_DCHECK_GE(reset.size(), output_size); // `reset` is over-allocated.
103 RTC_DCHECK_EQ(bias.size(), output_size);
104 RTC_DCHECK_EQ(weights.size(), input_size * output_size);
105 RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
106 RTC_DCHECK_EQ(state.size(), output_size);
107 std::array<float, kGruLayerMaxUnits> reset_x_state;
108 for (int o = 0; o < output_size; ++o) {
109 reset_x_state[o] = state[o] * reset[o];
110 }
111 for (int o = 0; o < output_size; ++o) {
112 float x = bias[o];
113 x += vector_math.DotProduct(input,
114 weights.subview(o * input_size, input_size));
115 x += vector_math.DotProduct(
116 {reset_x_state.data(), static_cast<size_t>(output_size)},
117 recurrent_weights.subview(o * output_size, output_size));
118 state[o] = update[o] * state[o] + (1.f - update[o]) * std::max(0.f, x);
119 }
120 }
121
122 } // namespace
123
// Creates a GRU layer with `input_size` inputs and `output_size` output units.
// `bias`, `weights` and `recurrent_weights` are quantized (int8) tensors; each
// of them is converted once, via `PreprocessGruTensor()`, into the float
// vectors stored in the layer. `cpu_features` is forwarded to `vector_math_`.
// `layer_name` is only used in the messages of the checks below. The
// constructor ends by calling `Reset()`.
GatedRecurrentLayer::GatedRecurrentLayer(
    const int input_size,
    const int output_size,
    const rtc::ArrayView<const int8_t> bias,
    const rtc::ArrayView<const int8_t> weights,
    const rtc::ArrayView<const int8_t> recurrent_weights,
    const AvailableCpuFeatures& cpu_features,
    absl::string_view layer_name)
    : input_size_(input_size),
      output_size_(output_size),
      bias_(PreprocessGruTensor(bias, output_size)),
      weights_(PreprocessGruTensor(weights, output_size)),
      recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)),
      vector_math_(cpu_features) {
  // The fixed-size buffers used by this layer hold at most `kGruLayerMaxUnits`
  // items.
  RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
      << "Insufficient GRU layer over-allocation (" << layer_name << ").";
  // One bias term per gate and output unit.
  RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
      << "Mismatching output size and bias terms array size (" << layer_name
      << ").";
  // One weight per gate, input and output unit.
  RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size())
      << "Mismatching input-output size and weight coefficients array size ("
      << layer_name << ").";
  // One recurrent weight per gate and pair of output units.
  RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_,
                recurrent_weights_.size())
      << "Mismatching input-output size and recurrent weight coefficients array"
         " size ("
      << layer_name << ").";
  // Start from a zeroed state.
  Reset();
}
153
// Defaulted destructor, defined out of line in this translation unit.
GatedRecurrentLayer::~GatedRecurrentLayer() = default;
155
Reset()156 void GatedRecurrentLayer::Reset() {
157 state_.fill(0.f);
158 }
159
ComputeOutput(rtc::ArrayView<const float> input)160 void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
161 RTC_DCHECK_EQ(input.size(), input_size_);
162
163 // The tensors below are organized as a sequence of flattened tensors for the
164 // `update`, `reset` and `state` gates.
165 rtc::ArrayView<const float> bias(bias_);
166 rtc::ArrayView<const float> weights(weights_);
167 rtc::ArrayView<const float> recurrent_weights(recurrent_weights_);
168 // Strides to access to the flattened tensors for a specific gate.
169 const int stride_weights = input_size_ * output_size_;
170 const int stride_recurrent_weights = output_size_ * output_size_;
171
172 rtc::ArrayView<float> state(state_.data(), output_size_);
173
174 // Update gate.
175 std::array<float, kGruLayerMaxUnits> update;
176 ComputeUpdateResetGate(
177 input_size_, output_size_, vector_math_, input, state,
178 bias.subview(0, output_size_), weights.subview(0, stride_weights),
179 recurrent_weights.subview(0, stride_recurrent_weights), update);
180 // Reset gate.
181 std::array<float, kGruLayerMaxUnits> reset;
182 ComputeUpdateResetGate(input_size_, output_size_, vector_math_, input, state,
183 bias.subview(output_size_, output_size_),
184 weights.subview(stride_weights, stride_weights),
185 recurrent_weights.subview(stride_recurrent_weights,
186 stride_recurrent_weights),
187 reset);
188 // State gate.
189 ComputeStateGate(input_size_, output_size_, vector_math_, input, update,
190 reset, bias.subview(2 * output_size_, output_size_),
191 weights.subview(2 * stride_weights, stride_weights),
192 recurrent_weights.subview(2 * stride_recurrent_weights,
193 stride_recurrent_weights),
194 state);
195 }
196
197 } // namespace rnn_vad
198 } // namespace webrtc
199