xref: /aosp_15_r20/external/webrtc/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
12 
13 #include "rtc_base/checks.h"
14 #include "rtc_base/numerics/safe_conversions.h"
15 #include "third_party/rnnoise/src/rnn_activations.h"
16 #include "third_party/rnnoise/src/rnn_vad_weights.h"
17 
18 namespace webrtc {
19 namespace rnn_vad {
20 namespace {
21 
22 constexpr int kNumGruGates = 3;  // Update, reset, output.
23 
PreprocessGruTensor(rtc::ArrayView<const int8_t> tensor_src,int output_size)24 std::vector<float> PreprocessGruTensor(rtc::ArrayView<const int8_t> tensor_src,
25                                        int output_size) {
26   // Transpose, cast and scale.
27   // `n` is the size of the first dimension of the 3-dim tensor `weights`.
28   const int n = rtc::CheckedDivExact(rtc::dchecked_cast<int>(tensor_src.size()),
29                                      output_size * kNumGruGates);
30   const int stride_src = kNumGruGates * output_size;
31   const int stride_dst = n * output_size;
32   std::vector<float> tensor_dst(tensor_src.size());
33   for (int g = 0; g < kNumGruGates; ++g) {
34     for (int o = 0; o < output_size; ++o) {
35       for (int i = 0; i < n; ++i) {
36         tensor_dst[g * stride_dst + o * n + i] =
37             ::rnnoise::kWeightsScale *
38             static_cast<float>(
39                 tensor_src[i * stride_src + g * output_size + o]);
40       }
41     }
42   }
43   return tensor_dst;
44 }
45 
46 // Computes the output for the update or the reset gate.
47 // Operation: `g = sigmoid(W^T∙i + R^T∙s + b)` where
48 // - `g`: output gate vector
49 // - `W`: weights matrix
50 // - `i`: input vector
51 // - `R`: recurrent weights matrix
52 // - `s`: state gate vector
53 // - `b`: bias vector
ComputeUpdateResetGate(int input_size,int output_size,const VectorMath & vector_math,rtc::ArrayView<const float> input,rtc::ArrayView<const float> state,rtc::ArrayView<const float> bias,rtc::ArrayView<const float> weights,rtc::ArrayView<const float> recurrent_weights,rtc::ArrayView<float> gate)54 void ComputeUpdateResetGate(int input_size,
55                             int output_size,
56                             const VectorMath& vector_math,
57                             rtc::ArrayView<const float> input,
58                             rtc::ArrayView<const float> state,
59                             rtc::ArrayView<const float> bias,
60                             rtc::ArrayView<const float> weights,
61                             rtc::ArrayView<const float> recurrent_weights,
62                             rtc::ArrayView<float> gate) {
63   RTC_DCHECK_EQ(input.size(), input_size);
64   RTC_DCHECK_EQ(state.size(), output_size);
65   RTC_DCHECK_EQ(bias.size(), output_size);
66   RTC_DCHECK_EQ(weights.size(), input_size * output_size);
67   RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
68   RTC_DCHECK_GE(gate.size(), output_size);  // `gate` is over-allocated.
69   for (int o = 0; o < output_size; ++o) {
70     float x = bias[o];
71     x += vector_math.DotProduct(input,
72                                 weights.subview(o * input_size, input_size));
73     x += vector_math.DotProduct(
74         state, recurrent_weights.subview(o * output_size, output_size));
75     gate[o] = ::rnnoise::SigmoidApproximated(x);
76   }
77 }
78 
79 // Computes the output for the state gate.
80 // Operation: `s' = u .* s + (1 - u) .* ReLU(W^T∙i + R^T∙(s .* r) + b)` where
81 // - `s'`: output state gate vector
82 // - `s`: previous state gate vector
83 // - `u`: update gate vector
84 // - `W`: weights matrix
85 // - `i`: input vector
86 // - `R`: recurrent weights matrix
87 // - `r`: reset gate vector
88 // - `b`: bias vector
89 // - `.*` element-wise product
ComputeStateGate(int input_size,int output_size,const VectorMath & vector_math,rtc::ArrayView<const float> input,rtc::ArrayView<const float> update,rtc::ArrayView<const float> reset,rtc::ArrayView<const float> bias,rtc::ArrayView<const float> weights,rtc::ArrayView<const float> recurrent_weights,rtc::ArrayView<float> state)90 void ComputeStateGate(int input_size,
91                       int output_size,
92                       const VectorMath& vector_math,
93                       rtc::ArrayView<const float> input,
94                       rtc::ArrayView<const float> update,
95                       rtc::ArrayView<const float> reset,
96                       rtc::ArrayView<const float> bias,
97                       rtc::ArrayView<const float> weights,
98                       rtc::ArrayView<const float> recurrent_weights,
99                       rtc::ArrayView<float> state) {
100   RTC_DCHECK_EQ(input.size(), input_size);
101   RTC_DCHECK_GE(update.size(), output_size);  // `update` is over-allocated.
102   RTC_DCHECK_GE(reset.size(), output_size);   // `reset` is over-allocated.
103   RTC_DCHECK_EQ(bias.size(), output_size);
104   RTC_DCHECK_EQ(weights.size(), input_size * output_size);
105   RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
106   RTC_DCHECK_EQ(state.size(), output_size);
107   std::array<float, kGruLayerMaxUnits> reset_x_state;
108   for (int o = 0; o < output_size; ++o) {
109     reset_x_state[o] = state[o] * reset[o];
110   }
111   for (int o = 0; o < output_size; ++o) {
112     float x = bias[o];
113     x += vector_math.DotProduct(input,
114                                 weights.subview(o * input_size, input_size));
115     x += vector_math.DotProduct(
116         {reset_x_state.data(), static_cast<size_t>(output_size)},
117         recurrent_weights.subview(o * output_size, output_size));
118     state[o] = update[o] * state[o] + (1.f - update[o]) * std::max(0.f, x);
119   }
120 }
121 
122 }  // namespace
123 
GatedRecurrentLayer(const int input_size,const int output_size,const rtc::ArrayView<const int8_t> bias,const rtc::ArrayView<const int8_t> weights,const rtc::ArrayView<const int8_t> recurrent_weights,const AvailableCpuFeatures & cpu_features,absl::string_view layer_name)124 GatedRecurrentLayer::GatedRecurrentLayer(
125     const int input_size,
126     const int output_size,
127     const rtc::ArrayView<const int8_t> bias,
128     const rtc::ArrayView<const int8_t> weights,
129     const rtc::ArrayView<const int8_t> recurrent_weights,
130     const AvailableCpuFeatures& cpu_features,
131     absl::string_view layer_name)
132     : input_size_(input_size),
133       output_size_(output_size),
134       bias_(PreprocessGruTensor(bias, output_size)),
135       weights_(PreprocessGruTensor(weights, output_size)),
136       recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)),
137       vector_math_(cpu_features) {
138   RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
139       << "Insufficient GRU layer over-allocation (" << layer_name << ").";
140   RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
141       << "Mismatching output size and bias terms array size (" << layer_name
142       << ").";
143   RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size())
144       << "Mismatching input-output size and weight coefficients array size ("
145       << layer_name << ").";
146   RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_,
147                 recurrent_weights_.size())
148       << "Mismatching input-output size and recurrent weight coefficients array"
149          " size ("
150       << layer_name << ").";
151   Reset();
152 }
153 
154 GatedRecurrentLayer::~GatedRecurrentLayer() = default;
155 
Reset()156 void GatedRecurrentLayer::Reset() {
157   state_.fill(0.f);
158 }
159 
ComputeOutput(rtc::ArrayView<const float> input)160 void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
161   RTC_DCHECK_EQ(input.size(), input_size_);
162 
163   // The tensors below are organized as a sequence of flattened tensors for the
164   // `update`, `reset` and `state` gates.
165   rtc::ArrayView<const float> bias(bias_);
166   rtc::ArrayView<const float> weights(weights_);
167   rtc::ArrayView<const float> recurrent_weights(recurrent_weights_);
168   // Strides to access to the flattened tensors for a specific gate.
169   const int stride_weights = input_size_ * output_size_;
170   const int stride_recurrent_weights = output_size_ * output_size_;
171 
172   rtc::ArrayView<float> state(state_.data(), output_size_);
173 
174   // Update gate.
175   std::array<float, kGruLayerMaxUnits> update;
176   ComputeUpdateResetGate(
177       input_size_, output_size_, vector_math_, input, state,
178       bias.subview(0, output_size_), weights.subview(0, stride_weights),
179       recurrent_weights.subview(0, stride_recurrent_weights), update);
180   // Reset gate.
181   std::array<float, kGruLayerMaxUnits> reset;
182   ComputeUpdateResetGate(input_size_, output_size_, vector_math_, input, state,
183                          bias.subview(output_size_, output_size_),
184                          weights.subview(stride_weights, stride_weights),
185                          recurrent_weights.subview(stride_recurrent_weights,
186                                                    stride_recurrent_weights),
187                          reset);
188   // State gate.
189   ComputeStateGate(input_size_, output_size_, vector_math_, input, update,
190                    reset, bias.subview(2 * output_size_, output_size_),
191                    weights.subview(2 * stride_weights, stride_weights),
192                    recurrent_weights.subview(2 * stride_recurrent_weights,
193                                              stride_recurrent_weights),
194                    state);
195 }
196 
197 }  // namespace rnn_vad
198 }  // namespace webrtc
199