1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #pragma once
7
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <random>
#include <vector>

#include <xnnpack.h>
#include <xnnpack/aligned-allocator.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>
22
// Twiddle table for radix-4 butterflies (bfly4) of a 256-point FFT.
// It holds 256 complex numbers as 512 interleaved int16 values
// (real, imaginary, real, imaginary, ...), each scaled by 32767 (Q15) and
// rounded, so complex entry k is approximately 32767 * exp(-2*pi*i*k/256).
// Even-indexed values (real parts) are
//   numpy.floor(0.5 + 32767 * numpy.cos(-2*pi*numpy.linspace(0, 255, num=256) / 256)).astype(numpy.int16).tolist()
// Odd-indexed values (imaginary parts) are
//   numpy.floor(0.5 + 32767 * numpy.sin(-2*pi*numpy.linspace(0, 255, num=256) / 256)).astype(numpy.int16).tolist()

static const int16_t xnn_reference_table_fft256_twiddle[512] = {
  32767, 0, 32757, -804, 32728, -1608, 32678, -2410,
  32609, -3212, 32521, -4011, 32412, -4808, 32285, -5602,
  32137, -6393, 31971, -7179, 31785, -7962, 31580, -8739,
  31356, -9512, 31113,-10278, 30852,-11039, 30571,-11793,
  30273,-12539, 29956,-13279, 29621,-14010, 29268,-14732,
  28898,-15446, 28510,-16151, 28105,-16846, 27683,-17530,
  27245,-18204, 26790,-18868, 26319,-19519, 25832,-20159,
  25329,-20787, 24811,-21403, 24279,-22005, 23731,-22594,
  23170,-23170, 22594,-23731, 22005,-24279, 21403,-24811,
  20787,-25329, 20159,-25832, 19519,-26319, 18868,-26790,
  18204,-27245, 17530,-27683, 16846,-28105, 16151,-28510,
  15446,-28898, 14732,-29268, 14010,-29621, 13279,-29956,
  12539,-30273, 11793,-30571, 11039,-30852, 10278,-31113,
  9512,-31356, 8739,-31580, 7962,-31785, 7179,-31971,
  6393,-32137, 5602,-32285, 4808,-32412, 4011,-32521,
  3212,-32609, 2410,-32678, 1608,-32728, 804,-32757,
  0,-32767, -804,-32757, -1608,-32728, -2410,-32678,
  -3212,-32609, -4011,-32521, -4808,-32412, -5602,-32285,
  -6393,-32137, -7179,-31971, -7962,-31785, -8739,-31580,
  -9512,-31356, -10278,-31113, -11039,-30852, -11793,-30571,
  -12539,-30273, -13279,-29956, -14010,-29621, -14732,-29268,
  -15446,-28898, -16151,-28510, -16846,-28105, -17530,-27683,
  -18204,-27245, -18868,-26790, -19519,-26319, -20159,-25832,
  -20787,-25329, -21403,-24811, -22005,-24279, -22594,-23731,
  -23170,-23170, -23731,-22594, -24279,-22005, -24811,-21403,
  -25329,-20787, -25832,-20159, -26319,-19519, -26790,-18868,
  -27245,-18204, -27683,-17530, -28105,-16846, -28510,-16151,
  -28898,-15446, -29268,-14732, -29621,-14010, -29956,-13279,
  -30273,-12539, -30571,-11793, -30852,-11039, -31113,-10278,
  -31356, -9512, -31580, -8739, -31785, -7962, -31971, -7179,
  -32137, -6393, -32285, -5602, -32412, -4808, -32521, -4011,
  -32609, -3212, -32678, -2410, -32728, -1608, -32757, -804,
  -32767, 0, -32757, 804, -32728, 1608, -32678, 2410,
  -32609, 3212, -32521, 4011, -32412, 4808, -32285, 5602,
  -32137, 6393, -31971, 7179, -31785, 7962, -31580, 8739,
  -31356, 9512, -31113, 10278, -30852, 11039, -30571, 11793,
  -30273, 12539, -29956, 13279, -29621, 14010, -29268, 14732,
  -28898, 15446, -28510, 16151, -28105, 16846, -27683, 17530,
  -27245, 18204, -26790, 18868, -26319, 19519, -25832, 20159,
  -25329, 20787, -24811, 21403, -24279, 22005, -23731, 22594,
  -23170, 23170, -22594, 23731, -22005, 24279, -21403, 24811,
  -20787, 25329, -20159, 25832, -19519, 26319, -18868, 26790,
  -18204, 27245, -17530, 27683, -16846, 28105, -16151, 28510,
  -15446, 28898, -14732, 29268, -14010, 29621, -13279, 29956,
  -12539, 30273, -11793, 30571, -11039, 30852, -10278, 31113,
  -9512, 31356, -8739, 31580, -7962, 31785, -7179, 31971,
  -6393, 32137, -5602, 32285, -4808, 32412, -4011, 32521,
  -3212, 32609, -2410, 32678, -1608, 32728, -804, 32757,
  0, 32767, 804, 32757, 1608, 32728, 2410, 32678,
  3212, 32609, 4011, 32521, 4808, 32412, 5602, 32285,
  6393, 32137, 7179, 31971, 7962, 31785, 8739, 31580,
  9512, 31356, 10278, 31113, 11039, 30852, 11793, 30571,
  12539, 30273, 13279, 29956, 14010, 29621, 14732, 29268,
  15446, 28898, 16151, 28510, 16846, 28105, 17530, 27683,
  18204, 27245, 18868, 26790, 19519, 26319, 20159, 25832,
  20787, 25329, 21403, 24811, 22005, 24279, 22594, 23731,
  23170, 23170, 23731, 22594, 24279, 22005, 24811, 21403,
  25329, 20787, 25832, 20159, 26319, 19519, 26790, 18868,
  27245, 18204, 27683, 17530, 28105, 16846, 28510, 16151,
  28898, 15446, 29268, 14732, 29621, 14010, 29956, 13279,
  30273, 12539, 30571, 11793, 30852, 11039, 31113, 10278,
  31356, 9512, 31580, 8739, 31785, 7962, 31971, 7179,
  32137, 6393, 32285, 5602, 32412, 4808, 32521, 4011,
  32609, 3212, 32678, 2410, 32728, 1608, 32757, 804
};
93
xnn_cs16_bfly4_reference(size_t samples,int16_t * data,const size_t stride,const int16_t * twiddle)94 void xnn_cs16_bfly4_reference(
95 size_t samples,
96 int16_t* data,
97 const size_t stride,
98 const int16_t* twiddle) {
99
100 const int16_t* tw1 = twiddle;
101 const int16_t* tw2 = tw1;
102 const int16_t* tw3 = tw1;
103 int16_t* out0 = data;
104 int16_t* out1 = data + samples * 2;
105 int16_t* out2 = data + samples * 4;
106 int16_t* out3 = data + samples * 6;
107
108 assert(samples != 0);
109 assert(stride != 0);
110 assert(twiddle != NULL);
111 assert(data != NULL);
112
113 do {
114 int32_t vout0_r = (int32_t) out0[0];
115 int32_t vout0_i = (int32_t) out0[1];
116 int32_t vout1_r = (int32_t) out1[0];
117 int32_t vout1_i = (int32_t) out1[1];
118 int32_t vout2_r = (int32_t) out2[0];
119 int32_t vout2_i = (int32_t) out2[1];
120 int32_t vout3_r = (int32_t) out3[0];
121 int32_t vout3_i = (int32_t) out3[1];
122
123 const int32_t tw1_r = (const int32_t) tw1[0];
124 const int32_t tw1_i = (const int32_t) tw1[1];
125 const int32_t tw2_r = (const int32_t) tw2[0];
126 const int32_t tw2_i = (const int32_t) tw2[1];
127 const int32_t tw3_r = (const int32_t) tw3[0];
128 const int32_t tw3_i = (const int32_t) tw3[1];
129
130 // Note 32767 / 4 = 8191. Should be 8192.
131 vout0_r = (vout0_r * 8191 + 16384) >> 15;
132 vout0_i = (vout0_i * 8191 + 16384) >> 15;
133 vout1_r = (vout1_r * 8191 + 16384) >> 15;
134 vout1_i = (vout1_i * 8191 + 16384) >> 15;
135 vout2_r = (vout2_r * 8191 + 16384) >> 15;
136 vout2_i = (vout2_i * 8191 + 16384) >> 15;
137 vout3_r = (vout3_r * 8191 + 16384) >> 15;
138 vout3_i = (vout3_i * 8191 + 16384) >> 15;
139
140 const int32_t vtmp0_r = math_asr_s32(vout1_r * tw1_r - vout1_i * tw1_i + 16384, 15);
141 const int32_t vtmp0_i = math_asr_s32(vout1_r * tw1_i + vout1_i * tw1_r + 16384, 15);
142 const int32_t vtmp1_r = math_asr_s32(vout2_r * tw2_r - vout2_i * tw2_i + 16384, 15);
143 const int32_t vtmp1_i = math_asr_s32(vout2_r * tw2_i + vout2_i * tw2_r + 16384, 15);
144 const int32_t vtmp2_r = math_asr_s32(vout3_r * tw3_r - vout3_i * tw3_i + 16384, 15);
145 const int32_t vtmp2_i = math_asr_s32(vout3_r * tw3_i + vout3_i * tw3_r + 16384, 15);
146
147 const int32_t vtmp5_r = vout0_r - vtmp1_r;
148 const int32_t vtmp5_i = vout0_i - vtmp1_i;
149 vout0_r += vtmp1_r;
150 vout0_i += vtmp1_i;
151 const int32_t vtmp3_r = vtmp0_r + vtmp2_r;
152 const int32_t vtmp3_i = vtmp0_i + vtmp2_i;
153 const int32_t vtmp4_r = vtmp0_r - vtmp2_r;
154 const int32_t vtmp4_i = vtmp0_i - vtmp2_i;
155 vout2_r = vout0_r - vtmp3_r;
156 vout2_i = vout0_i - vtmp3_i;
157
158 tw1 += stride * 2;
159 tw2 += stride * 4;
160 tw3 += stride * 6;
161 vout0_r += vtmp3_r;
162 vout0_i += vtmp3_i;
163
164 vout1_r = vtmp5_r + vtmp4_i;
165 vout1_i = vtmp5_i - vtmp4_r;
166 vout3_r = vtmp5_r - vtmp4_i;
167 vout3_i = vtmp5_i + vtmp4_r;
168
169 out0[0] = (int16_t) vout0_r;
170 out0[1] = (int16_t) vout0_i;
171 out1[0] = (int16_t) vout1_r;
172 out1[1] = (int16_t) vout1_i;
173 out2[0] = (int16_t) vout2_r;
174 out2[1] = (int16_t) vout2_i;
175 out3[0] = (int16_t) vout3_r;
176 out3[1] = (int16_t) vout3_i;
177 out0 += 2;
178 out1 += 2;
179 out2 += 2;
180 out3 += 2;
181 } while(--samples != 0);
182 }
183
184 class BFly4MicrokernelTester {
185 public:
samples(size_t samples)186 inline BFly4MicrokernelTester& samples(size_t samples) {
187 assert(samples != 0);
188 this->samples_ = samples;
189 return *this;
190 }
191
samples()192 inline size_t samples() const {
193 return this->samples_;
194 }
195
stride(uint32_t stride)196 inline BFly4MicrokernelTester& stride(uint32_t stride) {
197 this->stride_ = stride;
198 return *this;
199 }
200
stride()201 inline uint32_t stride() const {
202 return this->stride_;
203 }
204
iterations(size_t iterations)205 inline BFly4MicrokernelTester& iterations(size_t iterations) {
206 this->iterations_ = iterations;
207 return *this;
208 }
209
iterations()210 inline size_t iterations() const {
211 return this->iterations_;
212 }
213
Test(xnn_cs16_bfly4_ukernel_function bfly4)214 void Test(xnn_cs16_bfly4_ukernel_function bfly4) const {
215 std::random_device random_device;
216 auto rng = std::mt19937(random_device());
217 auto i16rng = std::bind(std::uniform_int_distribution<int16_t>(), std::ref(rng));
218 const size_t fft_size = (samples() == 1 ? 1 : (samples() * stride())) * 4; // 4 for bfly4.
219
220 // 256 complex numbers = fft_size * 2 = 512
221 std::vector<int16_t> y(fft_size * 2);
222 std::vector<int16_t> y_ref(fft_size * 2);
223
224 for (size_t iteration = 0; iteration < iterations(); iteration++) {
225 std::generate(y.begin(), y.end(), std::ref(i16rng));
226 y_ref = y;
227
228 // Compute reference results.
229 xnn_cs16_bfly4_reference(samples(), y_ref.data(), stride(), xnn_reference_table_fft256_twiddle);
230
231 // Call optimized micro-kernel.
232 bfly4(samples(), y.data(), stride(), xnn_reference_table_fft256_twiddle);
233
234 // Verify results.
235 for (size_t n = 0; n < fft_size * 2; n++) {
236 ASSERT_EQ(y[n], y_ref[n])
237 << "at sample " << n << " / " << fft_size
238 << "\nsamples " << samples()
239 << "\nstride " << stride();
240 }
241 }
242 }
243
244 private:
245 size_t samples_{1};
246 uint32_t stride_{1};
247 size_t iterations_{15};
248 };
249