xref: /aosp_15_r20/external/webrtc/modules/audio_processing/agc2/rnn_vad/rnn_gru.h (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
12 #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
13 
14 #include <array>
15 #include <vector>
16 
17 #include "absl/strings/string_view.h"
18 #include "api/array_view.h"
19 #include "modules/audio_processing/agc2/cpu_features.h"
20 #include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
21 
22 namespace webrtc {
23 namespace rnn_vad {
24 
// Maximum number of units (i.e., output size) supported by a GRU layer. It is
// also the fixed capacity of the state buffer held by `GatedRecurrentLayer`.
constexpr int kGruLayerMaxUnits = 24;
27 
28 // Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
29 // activation functions for the update/reset and output gates respectively.
30 class GatedRecurrentLayer {
31  public:
32   // Ctor. `output_size` cannot be greater than `kGruLayerMaxUnits`.
33   GatedRecurrentLayer(int input_size,
34                       int output_size,
35                       rtc::ArrayView<const int8_t> bias,
36                       rtc::ArrayView<const int8_t> weights,
37                       rtc::ArrayView<const int8_t> recurrent_weights,
38                       const AvailableCpuFeatures& cpu_features,
39                       absl::string_view layer_name);
40   GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
41   GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
42   ~GatedRecurrentLayer();
43 
44   // Returns the size of the input vector.
input_size()45   int input_size() const { return input_size_; }
46   // Returns the pointer to the first element of the output buffer.
data()47   const float* data() const { return state_.data(); }
48   // Returns the size of the output buffer.
size()49   int size() const { return output_size_; }
50 
51   // Resets the GRU state.
52   void Reset();
53   // Computes the recurrent layer output and updates the status.
54   void ComputeOutput(rtc::ArrayView<const float> input);
55 
56  private:
57   const int input_size_;
58   const int output_size_;
59   const std::vector<float> bias_;
60   const std::vector<float> weights_;
61   const std::vector<float> recurrent_weights_;
62   const VectorMath vector_math_;
63   // Over-allocated array with size equal to `output_size_`.
64   std::array<float, kGruLayerMaxUnits> state_;
65 };
66 
67 }  // namespace rnn_vad
68 }  // namespace webrtc
69 
70 #endif  // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
71