1 /*
2 * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
12
13 #include <algorithm>
14
15 #include "rtc_base/checks.h"
16
17 namespace webrtc {
18 namespace rnn_vad {
19 namespace {
20
21 constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT.
22 static_assert(1 << kAutoCorrelationFftOrder >
23 kNumLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
24 "");
25
26 } // namespace
27
AutoCorrelationCalculator()28 AutoCorrelationCalculator::AutoCorrelationCalculator()
29 : fft_(1 << kAutoCorrelationFftOrder, Pffft::FftType::kReal),
30 tmp_(fft_.CreateBuffer()),
31 X_(fft_.CreateBuffer()),
32 H_(fft_.CreateBuffer()) {}
33
34 AutoCorrelationCalculator::~AutoCorrelationCalculator() = default;
35
36 // The auto-correlations coefficients are computed as follows:
37 // |.........|...........| <- pitch buffer
38 // [ x (fixed) ]
39 // [ y_0 ]
40 // [ y_{m-1} ]
41 // x and y are sub-array of equal length; x is never moved, whereas y slides.
42 // The cross-correlation between y_0 and x corresponds to the auto-correlation
43 // for the maximum pitch period. Hence, the first value in `auto_corr` has an
44 // inverted lag equal to 0 that corresponds to a lag equal to the maximum
45 // pitch period.
ComputeOnPitchBuffer(rtc::ArrayView<const float,kBufSize12kHz> pitch_buf,rtc::ArrayView<float,kNumLags12kHz> auto_corr)46 void AutoCorrelationCalculator::ComputeOnPitchBuffer(
47 rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
48 rtc::ArrayView<float, kNumLags12kHz> auto_corr) {
49 RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz);
50 RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz);
51 constexpr int kFftFrameSize = 1 << kAutoCorrelationFftOrder;
52 constexpr int kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
53 static_assert(kConvolutionLength == kFrameSize20ms12kHz,
54 "Mismatch between pitch buffer size, frame size and maximum "
55 "pitch period.");
56 static_assert(kFftFrameSize > kNumLags12kHz + kConvolutionLength,
57 "The FFT length is not sufficiently big to avoid cyclic "
58 "convolution errors.");
59 auto tmp = tmp_->GetView();
60
61 // Compute the FFT for the reversed reference frame - i.e.,
62 // pitch_buf[-kConvolutionLength:].
63 std::reverse_copy(pitch_buf.end() - kConvolutionLength, pitch_buf.end(),
64 tmp.begin());
65 std::fill(tmp.begin() + kConvolutionLength, tmp.end(), 0.f);
66 fft_.ForwardTransform(*tmp_, H_.get(), /*ordered=*/false);
67
68 // Compute the FFT for the sliding frames chunk. The sliding frames are
69 // defined as pitch_buf[i:i+kConvolutionLength] where i in
70 // [0, kNumLags12kHz). The chunk includes all of them, hence it is
71 // defined as pitch_buf[:kNumLags12kHz+kConvolutionLength].
72 std::copy(pitch_buf.begin(),
73 pitch_buf.begin() + kConvolutionLength + kNumLags12kHz,
74 tmp.begin());
75 std::fill(tmp.begin() + kNumLags12kHz + kConvolutionLength, tmp.end(), 0.f);
76 fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false);
77
78 // Convolve in the frequency domain.
79 constexpr float kScalingFactor = 1.f / static_cast<float>(kFftFrameSize);
80 std::fill(tmp.begin(), tmp.end(), 0.f);
81 fft_.FrequencyDomainConvolve(*X_, *H_, tmp_.get(), kScalingFactor);
82 fft_.BackwardTransform(*tmp_, tmp_.get(), /*ordered=*/false);
83
84 // Extract the auto-correlation coefficients.
85 std::copy(tmp.begin() + kConvolutionLength - 1,
86 tmp.begin() + kConvolutionLength + kNumLags12kHz - 1,
87 auto_corr.begin());
88 }
89
90 } // namespace rnn_vad
91 } // namespace webrtc
92