1 /* 2 * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_ 12 #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_ 13 14 #include <array> 15 #include <vector> 16 17 #include "absl/strings/string_view.h" 18 #include "api/array_view.h" 19 #include "modules/audio_processing/agc2/cpu_features.h" 20 #include "modules/audio_processing/agc2/rnn_vad/vector_math.h" 21 22 namespace webrtc { 23 namespace rnn_vad { 24 25 // Maximum number of units for a GRU layer. 26 constexpr int kGruLayerMaxUnits = 24; 27 28 // Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as 29 // activation functions for the update/reset and output gates respectively. 30 class GatedRecurrentLayer { 31 public: 32 // Ctor. `output_size` cannot be greater than `kGruLayerMaxUnits`. 33 GatedRecurrentLayer(int input_size, 34 int output_size, 35 rtc::ArrayView<const int8_t> bias, 36 rtc::ArrayView<const int8_t> weights, 37 rtc::ArrayView<const int8_t> recurrent_weights, 38 const AvailableCpuFeatures& cpu_features, 39 absl::string_view layer_name); 40 GatedRecurrentLayer(const GatedRecurrentLayer&) = delete; 41 GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete; 42 ~GatedRecurrentLayer(); 43 44 // Returns the size of the input vector. input_size()45 int input_size() const { return input_size_; } 46 // Returns the pointer to the first element of the output buffer. data()47 const float* data() const { return state_.data(); } 48 // Returns the size of the output buffer. size()49 int size() const { return output_size_; } 50 51 // Resets the GRU state. 52 void Reset(); 53 // Computes the recurrent layer output and updates the status. 54 void ComputeOutput(rtc::ArrayView<const float> input); 55 56 private: 57 const int input_size_; 58 const int output_size_; 59 const std::vector<float> bias_; 60 const std::vector<float> weights_; 61 const std::vector<float> recurrent_weights_; 62 const VectorMath vector_math_; 63 // Over-allocated array with size equal to `output_size_`. 64 std::array<float, kGruLayerMaxUnits> state_; 65 }; 66 67 } // namespace rnn_vad 68 } // namespace webrtc 69 70 #endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_ 71