1 /* 2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SYMMETRIC_MATRIX_BUFFER_H_ 12 #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SYMMETRIC_MATRIX_BUFFER_H_ 13 14 #include <algorithm> 15 #include <array> 16 #include <cstring> 17 #include <utility> 18 19 #include "api/array_view.h" 20 #include "rtc_base/checks.h" 21 #include "rtc_base/numerics/safe_compare.h" 22 23 namespace webrtc { 24 namespace rnn_vad { 25 26 // Data structure to buffer the results of pair-wise comparisons between items 27 // stored in a ring buffer. Every time that the oldest item is replaced in the 28 // ring buffer, the new one is compared to the remaining items in the ring 29 // buffer. The results of such comparisons need to be buffered and automatically 30 // removed when one of the two corresponding items that have been compared is 31 // removed from the ring buffer. It is assumed that the comparison is symmetric 32 // and that comparing an item with itself is not needed. 33 template <typename T, int S> 34 class SymmetricMatrixBuffer { 35 static_assert(S > 2, ""); 36 37 public: 38 SymmetricMatrixBuffer() = default; 39 SymmetricMatrixBuffer(const SymmetricMatrixBuffer&) = delete; 40 SymmetricMatrixBuffer& operator=(const SymmetricMatrixBuffer&) = delete; 41 ~SymmetricMatrixBuffer() = default; 42 // Sets the buffer values to zero. Reset()43 void Reset() { 44 static_assert(std::is_arithmetic<T>::value, 45 "Integral or floating point required."); 46 buf_.fill(0); 47 } 48 // Pushes the results from the comparison between the most recent item and 49 // those that are still in the ring buffer. The first element in `values` must 50 // correspond to the comparison between the most recent item and the second 51 // most recent one in the ring buffer, whereas the last element in `values` 52 // must correspond to the comparison between the most recent item and the 53 // oldest one in the ring buffer. Push(rtc::ArrayView<T,S-1> values)54 void Push(rtc::ArrayView<T, S - 1> values) { 55 // Move the lower-right sub-matrix of size (S-2) x (S-2) one row up and one 56 // column left. 57 std::memmove(buf_.data(), buf_.data() + S, (buf_.size() - S) * sizeof(T)); 58 // Copy new values in the last column in the right order. 59 for (int i = 0; rtc::SafeLt(i, values.size()); ++i) { 60 const int index = (S - 1 - i) * (S - 1) - 1; 61 RTC_DCHECK_GE(index, 0); 62 RTC_DCHECK_LT(index, buf_.size()); 63 buf_[index] = values[i]; 64 } 65 } 66 // Reads the value that corresponds to comparison of two items in the ring 67 // buffer having delay `delay1` and `delay2`. The two arguments must not be 68 // equal and both must be in {0, ..., S - 1}. GetValue(int delay1,int delay2)69 T GetValue(int delay1, int delay2) const { 70 int row = S - 1 - delay1; 71 int col = S - 1 - delay2; 72 RTC_DCHECK_NE(row, col) << "The diagonal cannot be accessed."; 73 if (row > col) 74 std::swap(row, col); // Swap to access the upper-right triangular part. 75 RTC_DCHECK_LE(0, row); 76 RTC_DCHECK_LT(row, S - 1) << "Not enforcing row < col and row != col."; 77 RTC_DCHECK_LE(1, col) << "Not enforcing row < col and row != col."; 78 RTC_DCHECK_LT(col, S); 79 const int index = row * (S - 1) + (col - 1); 80 RTC_DCHECK_LE(0, index); 81 RTC_DCHECK_LT(index, buf_.size()); 82 return buf_[index]; 83 } 84 85 private: 86 // Encode an upper-right triangular matrix (excluding its diagonal) using a 87 // square matrix. This allows to move the data in Push() with one single 88 // operation. 89 std::array<T, (S - 1) * (S - 1)> buf_{}; 90 }; 91 92 } // namespace rnn_vad 93 } // namespace webrtc 94 95 #endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SYMMETRIC_MATRIX_BUFFER_H_ 96