// xref: /aosp_15_r20/external/libaom/test/fft_test.cc (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
11 
#include <math.h>
#include <string.h>

#include <algorithm>
#include <complex>
#include <ostream>
#include <vector>

#include "aom_dsp/fft_common.h"
#include "aom_mem/aom_mem.h"
#include "av1/common/common.h"
#include "config/aom_dsp_rtcd.h"
#include "gtest/gtest.h"
#include "test/acm_random.h"
25 
26 namespace {
27 
// Function-pointer type for the transforms exercised in this file. A transform
// is called with three float buffers, in order: input, temp, output.
using tform_fun_t = void (*)(const float *input, float *temp, float *output);
29 
30 // Simple 1D FFT implementation
31 template <typename InputType>
fft(const InputType * data,std::complex<float> * result,int n)32 void fft(const InputType *data, std::complex<float> *result, int n) {
33   if (n == 1) {
34     result[0] = data[0];
35     return;
36   }
37   std::vector<InputType> temp(n);
38   for (int k = 0; k < n / 2; ++k) {
39     temp[k] = data[2 * k];
40     temp[n / 2 + k] = data[2 * k + 1];
41   }
42   fft(&temp[0], result, n / 2);
43   fft(&temp[n / 2], result + n / 2, n / 2);
44   for (int k = 0; k < n / 2; ++k) {
45     std::complex<float> w = std::complex<float>((float)cos(2. * PI * k / n),
46                                                 (float)-sin(2. * PI * k / n));
47     std::complex<float> a = result[k];
48     std::complex<float> b = result[n / 2 + k];
49     result[k] = a + w * b;
50     result[n / 2 + k] = a - w * b;
51   }
52 }
53 
// Transposes, in place, the n x n row-major matrix stored in *data by swapping
// each element above the diagonal with its mirror below it.
void transpose(std::vector<std::complex<float> > *data, int n) {
  std::vector<std::complex<float> > &m = *data;
  for (int row = 0; row < n; ++row) {
    for (int col = row + 1; col < n; ++col) {
      std::swap(m[row * n + col], m[col * n + row]);
    }
  }
}
61 
62 // Simple 2D FFT implementation
63 template <class InputType>
fft2d(const InputType * input,int n)64 std::vector<std::complex<float> > fft2d(const InputType *input, int n) {
65   std::vector<std::complex<float> > rowfft(n * n);
66   std::vector<std::complex<float> > result(n * n);
67   for (int y = 0; y < n; ++y) {
68     fft(input + y * n, &rowfft[y * n], n);
69   }
70   transpose(&rowfft, n);
71   for (int y = 0; y < n; ++y) {
72     fft(&rowfft[y * n], &result[y * n], n);
73   }
74   transpose(&result, n);
75   return result;
76 }
77 
78 struct FFTTestArg {
79   int n;
80   void (*fft)(const float *input, float *temp, float *output);
FFTTestArg__anone4177d4c0111::FFTTestArg81   FFTTestArg(int n_in, tform_fun_t fft_in) : n(n_in), fft(fft_in) {}
82 };
83 
operator <<(std::ostream & os,const FFTTestArg & test_arg)84 std::ostream &operator<<(std::ostream &os, const FFTTestArg &test_arg) {
85   return os << "fft_arg { n:" << test_arg.n
86             << " fft:" << reinterpret_cast<const void *>(test_arg.fft) << " }";
87 }
88 
89 class FFT2DTest : public ::testing::TestWithParam<FFTTestArg> {
90  protected:
SetUp()91   void SetUp() override {
92     int n = GetParam().n;
93     input_ = (float *)aom_memalign(32, sizeof(*input_) * n * n);
94     temp_ = (float *)aom_memalign(32, sizeof(*temp_) * n * n);
95     output_ = (float *)aom_memalign(32, sizeof(*output_) * n * n * 2);
96     ASSERT_NE(input_, nullptr);
97     ASSERT_NE(temp_, nullptr);
98     ASSERT_NE(output_, nullptr);
99     memset(input_, 0, sizeof(*input_) * n * n);
100     memset(temp_, 0, sizeof(*temp_) * n * n);
101     memset(output_, 0, sizeof(*output_) * n * n * 2);
102   }
TearDown()103   void TearDown() override {
104     aom_free(input_);
105     aom_free(temp_);
106     aom_free(output_);
107   }
108   float *input_;
109   float *temp_;
110   float *output_;
111 };
112 
// Compares the transform under test with the reference fft2d() for every
// input that has a single sample set to 1 and all others 0.
TEST_P(FFT2DTest, Correct) {
  const int n = GetParam().n;
  // Only columns 0 .. n/2 of each row are compared.
  // NOTE(review): presumably the remaining columns are not produced by the
  // real-input transform - confirm against the aom_fft API.
  const int checked_cols = (n / 2) + 1;

  for (int i = 0; i < n * n; ++i) {
    input_[i] = 1;
    const std::vector<std::complex<float> > expected =
        fft2d<float>(&input_[0], n);
    GetParam().fft(&input_[0], &temp_[0], &output_[0]);

    for (int y = 0; y < n; ++y) {
      for (int x = 0; x < checked_cols; ++x) {
        // output_ holds interleaved (real, imag) pairs.
        const int idx = y * n + x;
        EXPECT_NEAR(expected[idx].real(), output_[2 * idx], 1e-5);
        EXPECT_NEAR(expected[idx].imag(), output_[2 * idx + 1], 1e-5);
      }
    }

    // Restore the all-zero input before moving to the next sample.
    input_[i] = 0;
  }
}
129 
// Runs the transform 1000 * (64 - n) times, each time with a single input
// sample set to 1, and accumulates output_[0]. The accumulated total is
// expected to equal the number of runs.
TEST_P(FFT2DTest, Benchmark) {
  const int n = GetParam().n;
  const int num_trials = 1000 * (64 - n);
  float sum = 0;
  for (int i = 0; i < num_trials; ++i) {
    // Cycle the position of the non-zero sample through the whole block.
    const int pos = i % (n * n);
    input_[pos] = 1;
    GetParam().fft(&input_[0], &temp_[0], &output_[0]);
    sum += output_[0];
    input_[pos] = 0;
  }
  EXPECT_NEAR(sum, num_trials, 1e-3);
}
142 
143 INSTANTIATE_TEST_SUITE_P(C, FFT2DTest,
144                          ::testing::Values(FFTTestArg(2, aom_fft2x2_float_c),
145                                            FFTTestArg(4, aom_fft4x4_float_c),
146                                            FFTTestArg(8, aom_fft8x8_float_c),
147                                            FFTTestArg(16, aom_fft16x16_float_c),
148                                            FFTTestArg(32,
149                                                       aom_fft32x32_float_c)));
150 #if AOM_ARCH_X86 || AOM_ARCH_X86_64
151 #if HAVE_SSE2
152 INSTANTIATE_TEST_SUITE_P(
153     SSE2, FFT2DTest,
154     ::testing::Values(FFTTestArg(4, aom_fft4x4_float_sse2),
155                       FFTTestArg(8, aom_fft8x8_float_sse2),
156                       FFTTestArg(16, aom_fft16x16_float_sse2),
157                       FFTTestArg(32, aom_fft32x32_float_sse2)));
158 #endif  // HAVE_SSE2
159 #if HAVE_AVX2
160 INSTANTIATE_TEST_SUITE_P(
161     AVX2, FFT2DTest,
162     ::testing::Values(FFTTestArg(8, aom_fft8x8_float_avx2),
163                       FFTTestArg(16, aom_fft16x16_float_avx2),
164                       FFTTestArg(32, aom_fft32x32_float_avx2)));
165 #endif  // HAVE_AVX2
166 #endif  // AOM_ARCH_X86 || AOM_ARCH_X86_64
167 
168 struct IFFTTestArg {
169   int n;
170   tform_fun_t ifft;
IFFTTestArg__anone4177d4c0111::IFFTTestArg171   IFFTTestArg(int n_in, tform_fun_t ifft_in) : n(n_in), ifft(ifft_in) {}
172 };
173 
operator <<(std::ostream & os,const IFFTTestArg & test_arg)174 std::ostream &operator<<(std::ostream &os, const IFFTTestArg &test_arg) {
175   return os << "ifft_arg { n:" << test_arg.n
176             << " fft:" << reinterpret_cast<const void *>(test_arg.ifft) << " }";
177 }
178 
179 class IFFT2DTest : public ::testing::TestWithParam<IFFTTestArg> {
180  protected:
SetUp()181   void SetUp() override {
182     int n = GetParam().n;
183     input_ = (float *)aom_memalign(32, sizeof(*input_) * n * n * 2);
184     temp_ = (float *)aom_memalign(32, sizeof(*temp_) * n * n * 2);
185     output_ = (float *)aom_memalign(32, sizeof(*output_) * n * n);
186     ASSERT_NE(input_, nullptr);
187     ASSERT_NE(temp_, nullptr);
188     ASSERT_NE(output_, nullptr);
189     memset(input_, 0, sizeof(*input_) * n * n * 2);
190     memset(temp_, 0, sizeof(*temp_) * n * n * 2);
191     memset(output_, 0, sizeof(*output_) * n * n);
192   }
TearDown()193   void TearDown() override {
194     aom_free(input_);
195     aom_free(temp_);
196     aom_free(output_);
197   }
198   float *input_;
199   float *temp_;
200   float *output_;
201 };
202 
// Round-trip check: for every input with a single sample set to 1, run the
// reference forward transform fft2d(), feed its result to the inverse
// transform under test, and expect the original samples back.
TEST_P(IFFT2DTest, Correctness) {
  const int n = GetParam().n;
  ASSERT_GE(n, 2);
  std::vector<float> expected(n * n);

  for (int y = 0; y < n; ++y) {
    for (int x = 0; x < n; ++x) {
      const int pos = y * n + x;
      expected[pos] = 1;

      // Pack the complex spectrum as interleaved (real, imag) floats.
      const std::vector<std::complex<float> > input_c = fft2d(&expected[0], n);
      for (int i = 0; i < n * n; ++i) {
        input_[2 * i + 0] = input_c[i].real();
        input_[2 * i + 1] = input_c[i].imag();
      }
      GetParam().ifft(&input_[0], &temp_[0], &output_[0]);

      // The output is divided by n * n before being compared with the
      // original samples.
      for (int yy = 0; yy < n; ++yy) {
        for (int xx = 0; xx < n; ++xx) {
          EXPECT_NEAR(expected[yy * n + xx], output_[yy * n + xx] / (n * n),
                      1e-5);
        }
      }

      // Restore the all-zero signal before moving to the next sample.
      expected[pos] = 0;
    }
  }
}
229 
// Runs the inverse transform 1000 * (64 - n) times, each time with a single
// input float set to 1, and accumulates output_[0]. The accumulated total is
// expected to be at least num_trials / 2.
TEST_P(IFFT2DTest, Benchmark) {
  const int n = GetParam().n;
  const int num_trials = 1000 * (64 - n);
  float sum = 0;
  for (int i = 0; i < num_trials; ++i) {
    // Cycle the non-zero float through the first n * n entries of input_.
    const int pos = i % (n * n);
    input_[pos] = 1;
    GetParam().ifft(&input_[0], &temp_[0], &output_[0]);
    sum += output_[0];
    input_[pos] = 0;
  }
  EXPECT_GE(sum, num_trials / 2);
}
242 INSTANTIATE_TEST_SUITE_P(
243     C, IFFT2DTest,
244     ::testing::Values(IFFTTestArg(2, aom_ifft2x2_float_c),
245                       IFFTTestArg(4, aom_ifft4x4_float_c),
246                       IFFTTestArg(8, aom_ifft8x8_float_c),
247                       IFFTTestArg(16, aom_ifft16x16_float_c),
248                       IFFTTestArg(32, aom_ifft32x32_float_c)));
249 #if AOM_ARCH_X86 || AOM_ARCH_X86_64
250 #if HAVE_SSE2
251 INSTANTIATE_TEST_SUITE_P(
252     SSE2, IFFT2DTest,
253     ::testing::Values(IFFTTestArg(4, aom_ifft4x4_float_sse2),
254                       IFFTTestArg(8, aom_ifft8x8_float_sse2),
255                       IFFTTestArg(16, aom_ifft16x16_float_sse2),
256                       IFFTTestArg(32, aom_ifft32x32_float_sse2)));
257 #endif  // HAVE_SSE2
258 
259 #if HAVE_AVX2
260 INSTANTIATE_TEST_SUITE_P(
261     AVX2, IFFT2DTest,
262     ::testing::Values(IFFTTestArg(8, aom_ifft8x8_float_avx2),
263                       IFFTTestArg(16, aom_ifft16x16_float_avx2),
264                       IFFTTestArg(32, aom_ifft32x32_float_avx2)));
265 #endif  // HAVE_AVX2
266 #endif  // AOM_ARCH_X86 || AOM_ARCH_X86_64
267 
268 }  // namespace
269