xref: /aosp_15_r20/external/webrtc/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
12 
13 #include <algorithm>
14 
15 #include "rtc_base/checks.h"
16 
17 namespace webrtc {
18 namespace rnn_vad {
19 namespace {
20 
21 constexpr int kAutoCorrelationFftOrder = 9;  // Length-512 FFT.
22 static_assert(1 << kAutoCorrelationFftOrder >
23                   kNumLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
24               "");
25 
26 }  // namespace
27 
AutoCorrelationCalculator()28 AutoCorrelationCalculator::AutoCorrelationCalculator()
29     : fft_(1 << kAutoCorrelationFftOrder, Pffft::FftType::kReal),
30       tmp_(fft_.CreateBuffer()),
31       X_(fft_.CreateBuffer()),
32       H_(fft_.CreateBuffer()) {}
33 
34 AutoCorrelationCalculator::~AutoCorrelationCalculator() = default;
35 
36 // The auto-correlations coefficients are computed as follows:
37 // |.........|...........|  <- pitch buffer
38 //           [ x (fixed) ]
39 // [   y_0   ]
40 //         [ y_{m-1} ]
41 // x and y are sub-array of equal length; x is never moved, whereas y slides.
42 // The cross-correlation between y_0 and x corresponds to the auto-correlation
43 // for the maximum pitch period. Hence, the first value in `auto_corr` has an
44 // inverted lag equal to 0 that corresponds to a lag equal to the maximum
45 // pitch period.
ComputeOnPitchBuffer(rtc::ArrayView<const float,kBufSize12kHz> pitch_buf,rtc::ArrayView<float,kNumLags12kHz> auto_corr)46 void AutoCorrelationCalculator::ComputeOnPitchBuffer(
47     rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
48     rtc::ArrayView<float, kNumLags12kHz> auto_corr) {
49   RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz);
50   RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz);
51   constexpr int kFftFrameSize = 1 << kAutoCorrelationFftOrder;
52   constexpr int kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
53   static_assert(kConvolutionLength == kFrameSize20ms12kHz,
54                 "Mismatch between pitch buffer size, frame size and maximum "
55                 "pitch period.");
56   static_assert(kFftFrameSize > kNumLags12kHz + kConvolutionLength,
57                 "The FFT length is not sufficiently big to avoid cyclic "
58                 "convolution errors.");
59   auto tmp = tmp_->GetView();
60 
61   // Compute the FFT for the reversed reference frame - i.e.,
62   // pitch_buf[-kConvolutionLength:].
63   std::reverse_copy(pitch_buf.end() - kConvolutionLength, pitch_buf.end(),
64                     tmp.begin());
65   std::fill(tmp.begin() + kConvolutionLength, tmp.end(), 0.f);
66   fft_.ForwardTransform(*tmp_, H_.get(), /*ordered=*/false);
67 
68   // Compute the FFT for the sliding frames chunk. The sliding frames are
69   // defined as pitch_buf[i:i+kConvolutionLength] where i in
70   // [0, kNumLags12kHz). The chunk includes all of them, hence it is
71   // defined as pitch_buf[:kNumLags12kHz+kConvolutionLength].
72   std::copy(pitch_buf.begin(),
73             pitch_buf.begin() + kConvolutionLength + kNumLags12kHz,
74             tmp.begin());
75   std::fill(tmp.begin() + kNumLags12kHz + kConvolutionLength, tmp.end(), 0.f);
76   fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false);
77 
78   // Convolve in the frequency domain.
79   constexpr float kScalingFactor = 1.f / static_cast<float>(kFftFrameSize);
80   std::fill(tmp.begin(), tmp.end(), 0.f);
81   fft_.FrequencyDomainConvolve(*X_, *H_, tmp_.get(), kScalingFactor);
82   fft_.BackwardTransform(*tmp_, tmp_.get(), /*ordered=*/false);
83 
84   // Extract the auto-correlation coefficients.
85   std::copy(tmp.begin() + kConvolutionLength - 1,
86             tmp.begin() + kConvolutionLength + kNumLags12kHz - 1,
87             auto_corr.begin());
88 }
89 
90 }  // namespace rnn_vad
91 }  // namespace webrtc
92