1 /*
2 * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/utility/pffft_wrapper.h"
12
13 #include <algorithm>
14 #include <cstdlib>
15 #include <memory>
16
17 #include "test/gtest.h"
18 #include "third_party/pffft/src/pffft.h"
19
20 namespace webrtc {
21 namespace test {
22 namespace {
23
// Exclusive upper bound of the FFT sizes scanned when checking that every
// size reported as valid can actually be used to build a wrapper.
constexpr size_t kMaxValidSizeCheck = 1024;

// FFT sizes exercised by the bit-exactness test. The products are kept in
// factored form (5 * 96, 5 * 128) as in the original list.
constexpr int kFftSizes[] = {16,     32,   64,   96,   128,  160,
                             192,    256,  288,  384,  5 * 96, 512,
                             576,    5 * 128, 800, 864, 1024, 2048,
                             2592,   4000, 4096, 12000, 36864};
29
CreatePffftWrapper(size_t fft_size,Pffft::FftType fft_type)30 void CreatePffftWrapper(size_t fft_size, Pffft::FftType fft_type) {
31 Pffft pffft_wrapper(fft_size, fft_type);
32 }
33
AllocateScratchBuffer(size_t fft_size,bool complex_fft)34 float* AllocateScratchBuffer(size_t fft_size, bool complex_fft) {
35 return static_cast<float*>(
36 pffft_aligned_malloc(fft_size * (complex_fft ? 2 : 1) * sizeof(float)));
37 }
38
// Returns a pseudo-random value in [0, 1] derived from std::rand(). The
// sequence is reproducible once std::srand() has been called with a fixed
// seed.
double frand() {
  const int sample = std::rand();
  return static_cast<double>(sample) / RAND_MAX;
}
42
ExpectArrayViewsEquality(rtc::ArrayView<const float> a,rtc::ArrayView<const float> b)43 void ExpectArrayViewsEquality(rtc::ArrayView<const float> a,
44 rtc::ArrayView<const float> b) {
45 ASSERT_EQ(a.size(), b.size());
46 for (size_t i = 0; i < a.size(); ++i) {
47 SCOPED_TRACE(i);
48 EXPECT_EQ(a[i], b[i]);
49 }
50 }
51
52 // Compares the output of the PFFFT C++ wrapper to that of the C PFFFT.
53 // Bit-exactness is expected.
PffftValidateWrapper(size_t fft_size,bool complex_fft)54 void PffftValidateWrapper(size_t fft_size, bool complex_fft) {
55 // Always use the same seed to avoid flakiness.
56 std::srand(0);
57
58 // Init PFFFT.
59 PFFFT_Setup* pffft_status =
60 pffft_new_setup(fft_size, complex_fft ? PFFFT_COMPLEX : PFFFT_REAL);
61 ASSERT_TRUE(pffft_status) << "FFT size (" << fft_size << ") not supported.";
62 size_t num_floats = fft_size * (complex_fft ? 2 : 1);
63 int num_bytes = static_cast<int>(num_floats) * sizeof(float);
64 float* in = static_cast<float*>(pffft_aligned_malloc(num_bytes));
65 float* out = static_cast<float*>(pffft_aligned_malloc(num_bytes));
66 float* scratch = AllocateScratchBuffer(fft_size, complex_fft);
67
68 // Init PFFFT C++ wrapper.
69 Pffft::FftType fft_type =
70 complex_fft ? Pffft::FftType::kComplex : Pffft::FftType::kReal;
71 ASSERT_TRUE(Pffft::IsValidFftSize(fft_size, fft_type));
72 Pffft pffft_wrapper(fft_size, fft_type);
73 auto in_wrapper = pffft_wrapper.CreateBuffer();
74 auto out_wrapper = pffft_wrapper.CreateBuffer();
75
76 // Input and output buffers views.
77 rtc::ArrayView<float> in_view(in, num_floats);
78 rtc::ArrayView<float> out_view(out, num_floats);
79 auto in_wrapper_view = in_wrapper->GetView();
80 EXPECT_EQ(in_wrapper_view.size(), num_floats);
81 auto out_wrapper_view = out_wrapper->GetConstView();
82 EXPECT_EQ(out_wrapper_view.size(), num_floats);
83
84 // Random input data.
85 for (size_t i = 0; i < num_floats; ++i) {
86 in_wrapper_view[i] = in[i] = static_cast<float>(frand() * 2.0 - 1.0);
87 }
88
89 // Forward transform.
90 pffft_transform(pffft_status, in, out, scratch, PFFFT_FORWARD);
91 pffft_wrapper.ForwardTransform(*in_wrapper, out_wrapper.get(),
92 /*ordered=*/false);
93 ExpectArrayViewsEquality(out_view, out_wrapper_view);
94
95 // Copy the FFT results into the input buffers to compute the backward FFT.
96 std::copy(out_view.begin(), out_view.end(), in_view.begin());
97 std::copy(out_wrapper_view.begin(), out_wrapper_view.end(),
98 in_wrapper_view.begin());
99
100 // Backward transform.
101 pffft_transform(pffft_status, in, out, scratch, PFFFT_BACKWARD);
102 pffft_wrapper.BackwardTransform(*in_wrapper, out_wrapper.get(),
103 /*ordered=*/false);
104 ExpectArrayViewsEquality(out_view, out_wrapper_view);
105
106 pffft_destroy_setup(pffft_status);
107 pffft_aligned_free(in);
108 pffft_aligned_free(out);
109 pffft_aligned_free(scratch);
110 }
111
112 } // namespace
113
// Every size below kMaxValidSizeCheck that IsValidFftSize() accepts, for
// either FFT type, must be usable to construct a wrapper.
TEST(PffftTest, CreateWrapperWithValidSize) {
  const Pffft::FftType kFftTypes[] = {Pffft::FftType::kReal,
                                      Pffft::FftType::kComplex};
  for (size_t fft_size = 0; fft_size < kMaxValidSizeCheck; ++fft_size) {
    SCOPED_TRACE(fft_size);
    for (const Pffft::FftType fft_type : kFftTypes) {
      if (!Pffft::IsValidFftSize(fft_size, fft_type)) {
        continue;
      }
      CreatePffftWrapper(fft_size, fft_type);
    }
  }
}
125
#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)

// Fixture parameterized by an FFT size that must be rejected.
class PffftInvalidSizeDeathTest : public ::testing::TestWithParam<size_t> {};

TEST_P(PffftInvalidSizeDeathTest, DoNotCreateRealWrapper) {
  const size_t fft_size = GetParam();
  ASSERT_FALSE(Pffft::IsValidFftSize(fft_size, Pffft::FftType::kReal));
  EXPECT_DEATH(CreatePffftWrapper(fft_size, Pffft::FftType::kReal), "");
}

TEST_P(PffftInvalidSizeDeathTest, DoNotCreateComplexWrapper) {
  const size_t fft_size = GetParam();
  ASSERT_FALSE(Pffft::IsValidFftSize(fft_size, Pffft::FftType::kComplex));
  EXPECT_DEATH(CreatePffftWrapper(fft_size, Pffft::FftType::kComplex), "");
}

// Each value is one more than an entry of kFftSizes (up to 1024).
constexpr size_t kInvalidFftSizes[] = {17,  33,  65,  97,  129, 161,
                                       193, 257, 289, 385, 481, 513,
                                       577, 641, 801, 865, 1025};

INSTANTIATE_TEST_SUITE_P(PffftTest,
                         PffftInvalidSizeDeathTest,
                         ::testing::ValuesIn(kInvalidFftSizes));

#endif
165
166 // TODO(https://crbug.com/webrtc/9577): Enable once SIMD is always enabled.
TEST(PffftTest,DISABLED_CheckSimd)167 TEST(PffftTest, DISABLED_CheckSimd) {
168 EXPECT_TRUE(Pffft::IsSimdEnabled());
169 }
170
// Checks bit-exactness against the C PFFFT for every size in kFftSizes,
// first for the real transform and then for the complex one.
TEST(PffftTest, FftBitExactness) {
  for (const int fft_size : kFftSizes) {
    SCOPED_TRACE(fft_size);
    // Size 16 only runs the complex check (presumably 16 is not a valid real
    // FFT size; PffftValidateWrapper asserts validity).
    const bool check_real_fft = fft_size != 16;
    if (check_real_fft) {
      PffftValidateWrapper(fft_size, /*complex_fft=*/false);
    }
    PffftValidateWrapper(fft_size, /*complex_fft=*/true);
  }
}
180
181 } // namespace test
182 } // namespace webrtc
183