xref: /aosp_15_r20/external/webrtc/modules/audio_processing/agc2/rnn_vad/pitch_search.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
/*
 *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
10 
11 #include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
12 
13 #include <array>
14 #include <cstddef>
15 
16 #include "rtc_base/checks.h"
17 
18 namespace webrtc {
19 namespace rnn_vad {
20 
PitchEstimator(const AvailableCpuFeatures & cpu_features)21 PitchEstimator::PitchEstimator(const AvailableCpuFeatures& cpu_features)
22     : cpu_features_(cpu_features),
23       y_energy_24kHz_(kRefineNumLags24kHz, 0.f),
24       pitch_buffer_12kHz_(kBufSize12kHz),
25       auto_correlation_12kHz_(kNumLags12kHz) {}
26 
27 PitchEstimator::~PitchEstimator() = default;
28 
Estimate(rtc::ArrayView<const float,kBufSize24kHz> pitch_buffer)29 int PitchEstimator::Estimate(
30     rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
31   rtc::ArrayView<float, kBufSize12kHz> pitch_buffer_12kHz_view(
32       pitch_buffer_12kHz_.data(), kBufSize12kHz);
33   RTC_DCHECK_EQ(pitch_buffer_12kHz_.size(), pitch_buffer_12kHz_view.size());
34   rtc::ArrayView<float, kNumLags12kHz> auto_correlation_12kHz_view(
35       auto_correlation_12kHz_.data(), kNumLags12kHz);
36   RTC_DCHECK_EQ(auto_correlation_12kHz_.size(),
37                 auto_correlation_12kHz_view.size());
38 
39   // TODO(bugs.chromium.org/10480): Use `cpu_features_` to estimate pitch.
40   // Perform the initial pitch search at 12 kHz.
41   Decimate2x(pitch_buffer, pitch_buffer_12kHz_view);
42   auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buffer_12kHz_view,
43                                              auto_correlation_12kHz_view);
44   CandidatePitchPeriods pitch_periods = ComputePitchPeriod12kHz(
45       pitch_buffer_12kHz_view, auto_correlation_12kHz_view, cpu_features_);
46   // The refinement is done using the pitch buffer that contains 24 kHz samples.
47   // Therefore, adapt the inverted lags in `pitch_candidates_inv_lags` from 12
48   // to 24 kHz.
49   pitch_periods.best *= 2;
50   pitch_periods.second_best *= 2;
51 
52   // Refine the initial pitch period estimation from 12 kHz to 48 kHz.
53   // Pre-compute frame energies at 24 kHz.
54   rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_24kHz_view(
55       y_energy_24kHz_.data(), kRefineNumLags24kHz);
56   RTC_DCHECK_EQ(y_energy_24kHz_.size(), y_energy_24kHz_view.size());
57   ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, y_energy_24kHz_view,
58                                          cpu_features_);
59   // Estimation at 48 kHz.
60   const int pitch_lag_48kHz = ComputePitchPeriod48kHz(
61       pitch_buffer, y_energy_24kHz_view, pitch_periods, cpu_features_);
62   last_pitch_48kHz_ = ComputeExtendedPitchPeriod48kHz(
63       pitch_buffer, y_energy_24kHz_view,
64       /*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_lag_48kHz,
65       last_pitch_48kHz_, cpu_features_);
66   return last_pitch_48kHz_.period;
67 }
68 
69 }  // namespace rnn_vad
70 }  // namespace webrtc
71