/*
 *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_

// NOTE(review): none of <stddef.h>, <sys/types.h>, <array> or <vector> is
// referenced by the declarations in this header; they may be needed only by
// the implementation file. Verify before removing them.
#include <stddef.h>
#include <sys/types.h>

#include <array>
#include <vector>

#include "api/array_view.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/rnn_fc.h"
#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"

namespace webrtc {
namespace rnn_vad {

// Recurrent network with hard-coded architecture and weights for voice activity
// detection.
//
// The network is composed of three layers: a fully connected input layer, a
// gated recurrent (GRU) hidden layer and a fully connected output layer. The
// recurrent layer makes the instance stateful: each call to
// `ComputeVadProbability()` depends on the frames previously observed since
// construction or the last `Reset()`.
//
// Instances are neither copyable nor movable. The copy operations are
// deleted, and because they are user-declared (as is the destructor), no
// implicit move operations are generated.
class RnnVad {
 public:
  // Builds the network. `cpu_features` is forwarded to the layers;
  // NOTE(review): presumably it selects an optimized (e.g. SIMD) code path for
  // the layer computations. Confirm in rnn.cc / rnn_fc.h / rnn_gru.h.
  explicit RnnVad(const AvailableCpuFeatures& cpu_features);
  RnnVad(const RnnVad&) = delete;
  RnnVad& operator=(const RnnVad&) = delete;
  ~RnnVad();
  // Resets the network.
  // NOTE(review): presumably this clears the recurrent (hidden) state so that
  // subsequent calls behave as on a freshly constructed instance. Confirm in
  // rnn.cc.
  void Reset();
  // Observes `feature_vector` and `is_silence`, updates the RNN and returns the
  // current voice probability.
  // `feature_vector` must contain exactly `kFeatureVectorSize` elements; the
  // fixed-size `rtc::ArrayView` enforces this at compile time.
  // The method is non-const because it advances the recurrent state.
  // NOTE(review): the returned value is presumably in [0, 1], and `is_silence`
  // presumably short-circuits or zeroes the output. Neither is visible here,
  // so confirm in rnn.cc.
  float ComputeVadProbability(
      rtc::ArrayView<const float, kFeatureVectorSize> feature_vector,
      bool is_silence);

 private:
  // Layers in network order. Keep this declaration order: members are
  // initialized in declaration order, regardless of the constructor's
  // initializer list.
  FullyConnectedLayer input_;
  GatedRecurrentLayer hidden_;
  FullyConnectedLayer output_;
};

}  // namespace rnn_vad
}  // namespace webrtc

#endif  // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_