1*c217d954SCole Faust /*
2*c217d954SCole Faust * Copyright (c) 2019-2020 Arm Limited.
3*c217d954SCole Faust *
4*c217d954SCole Faust * SPDX-License-Identifier: MIT
5*c217d954SCole Faust *
6*c217d954SCole Faust * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust *
13*c217d954SCole Faust * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust * copies or substantial portions of the Software.
15*c217d954SCole Faust *
16*c217d954SCole Faust * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust * SOFTWARE.
23*c217d954SCole Faust */
24*c217d954SCole Faust #include "DFT.h"
25*c217d954SCole Faust
26*c217d954SCole Faust #include "PadLayer.h"
27*c217d954SCole Faust #include "Permute.h"
28*c217d954SCole Faust #include "Reverse.h"
29*c217d954SCole Faust #include "SliceOperations.h"
30*c217d954SCole Faust #include "support/ToolchainSupport.h"
31*c217d954SCole Faust
32*c217d954SCole Faust #include <cmath>
33*c217d954SCole Faust
34*c217d954SCole Faust namespace arm_compute
35*c217d954SCole Faust {
36*c217d954SCole Faust namespace test
37*c217d954SCole Faust {
38*c217d954SCole Faust namespace validation
39*c217d954SCole Faust {
40*c217d954SCole Faust namespace reference
41*c217d954SCole Faust {
42*c217d954SCole Faust namespace
43*c217d954SCole Faust {
/** Performs a one dimensional DFT on a given real sequence.
 *
 * Only the first @p K frequency bins are computed. For a real input the remaining
 * bins are complex conjugates of these, so callers typically pass K = N / 2 + 1.
 *
 * @param[in]  src_ptr Pointer to the real input sequence (@p N values).
 * @param[in]  N       Size of the input sequence.
 * @param[out] dst_ptr Pointer to the complex output sequence, stored as interleaved
 *                     (real, imaginary) pairs (2 * @p K values).
 * @param[in]  K       Number of complex output bins to compute.
 */
template <typename T>
void rdft_1d_step(const T *src_ptr, size_t N, T *dst_ptr, size_t K)
{
#if defined(_OPENMP)
#pragma omp parallel for
#endif /* _OPENMP */
    for(unsigned int k = 0; k < K; ++k)
    {
        // Accumulate in double: this is a validation reference and should be as accurate as practical
        double Xr = 0;
        double Xi = 0;
        for(unsigned int n = 0; n < N; ++n)
        {
            // Reduce k * n modulo N exactly before forming the angle. This keeps alpha in [0, 2*pi),
            // so the twiddle factor does not lose precision when k * n is large.
            const uint64_t phase = (static_cast<uint64_t>(k) * n) % N;
            const double   alpha = (2 * M_PI * phase) / N;
            const float    val_r = src_ptr[n];
            // Input is real, so the imaginary-part terms are zero and skipped
            Xr += val_r * std::cos(alpha);
            Xi -= val_r * std::sin(alpha);
        }

        dst_ptr[k * 2]     = static_cast<float>(Xr);
        dst_ptr[k * 2 + 1] = static_cast<float>(Xi);
    }
}
74*c217d954SCole Faust
/** Performs a one dimensional forward DFT on a given complex sequence.
 *
 * @param[in]  src_ptr Pointer to the complex input sequence, stored as interleaved (real, imaginary) pairs.
 * @param[out] dst_ptr Pointer to the complex output sequence, stored as interleaved (real, imaginary) pairs.
 * @param[in]  N       Number of complex elements in both sequences.
 */
template <typename T>
void dft_1d_step(const T *src_ptr, T *dst_ptr, size_t N)
{
#if defined(_OPENMP)
#pragma omp parallel for
#endif /* _OPENMP */
    for(unsigned int k = 0; k < N; ++k)
    {
        // Accumulate in double: this is a validation reference and should be as accurate as practical
        double Xr = 0;
        double Xi = 0;
        for(unsigned int n = 0; n < N; ++n)
        {
            // Exact integer reduction of k * n modulo N keeps the angle in [0, 2*pi) without precision loss
            const uint64_t phase     = (static_cast<uint64_t>(k) * n) % N;
            const double   alpha     = (2 * M_PI * phase) / N;
            const float    val_r     = src_ptr[2 * n];
            const float    val_i     = src_ptr[2 * n + 1];
            const double   cos_alpha = std::cos(alpha);
            const double   sin_alpha = std::sin(alpha);

            // (val_r + i*val_i) * (cos_alpha - i*sin_alpha)
            Xr += val_r * cos_alpha + val_i * sin_alpha;
            Xi += val_i * cos_alpha - val_r * sin_alpha;
        }

        dst_ptr[k * 2]     = static_cast<float>(Xr);
        dst_ptr[k * 2 + 1] = static_cast<float>(Xi);
    }
}
107*c217d954SCole Faust
/** Performs a one dimensional inverse DFT from a half complex spectrum to a real sequence.
 *
 * @p src_ptr holds the first @p K bins of the spectrum of a real signal. The missing bins
 * K..N-1 are rebuilt as complex conjugates of the stored bins, walking backwards from
 * bin K-1 (odd N) or K-2 (even N). Only the real part of the result is produced, and no
 * 1/N normalisation is applied.
 *
 * @param[in]  src_ptr Pointer to the complex input spectrum, stored as interleaved
 *                     (real, imaginary) pairs (2 * @p K values).
 * @param[in]  K       Number of complex input bins.
 * @param[out] dst_ptr Pointer to the real output sequence (@p N values).
 * @param[in]  N       Size of the output sequence.
 */
template <typename T>
void irdft_1d_step(const T *src_ptr, size_t K, T *dst_ptr, size_t N)
{
    const bool         is_odd     = N % 2;
    const unsigned int Nleft      = N - K;
    const int          tail_start = is_odd ? K - 1 : K - 2;
#if defined(_OPENMP)
#pragma omp parallel for
#endif /* _OPENMP */
    for(unsigned int n = 0; n < N; ++n)
    {
        // Accumulate in double: this is a validation reference and should be as accurate as practical
        double xr = 0;
        for(unsigned int k = 0; k < K; ++k)
        {
            // Exact integer reduction of k * n modulo N keeps the angle in [0, 2*pi) without precision loss
            const uint64_t phase = (static_cast<uint64_t>(k) * n) % N;
            const double   alpha = (2 * M_PI * phase) / N;
            const float    re    = src_ptr[2 * k];
            const float    im    = src_ptr[2 * k + 1];
            // Re{(re + i*im) * (cos + i*sin)}
            xr += re * std::cos(alpha) - im * std::sin(alpha);
        }

        // Remaining bins: X[K + k] = conj(X[j]), with j walking back from tail_start
        unsigned int j = tail_start;
        for(unsigned int k = 0; k < Nleft; ++k)
        {
            const uint64_t phase = (static_cast<uint64_t>(k + K) * n) % N;
            const double   alpha = (2 * M_PI * phase) / N;
            const float    re    = src_ptr[2 * j];
            const float    im    = src_ptr[2 * j + 1];
            // Re{(re - i*im) * (cos + i*sin)}
            xr += re * std::cos(alpha) + im * std::sin(alpha);
            --j;
        }

        dst_ptr[n] = static_cast<float>(xr);
    }
}
144*c217d954SCole Faust
/** Performs a one dimensional inverse DFT on a given complex sequence (no 1/N normalisation).
 *
 * @param[in]  src_ptr Pointer to the complex input sequence, stored as interleaved (real, imaginary) pairs.
 * @param[out] dst_ptr Pointer to the complex output sequence, stored as interleaved (real, imaginary) pairs.
 * @param[in]  N       Number of complex elements in both sequences.
 */
template <typename T>
void idft_1d_step(const T *src_ptr, T *dst_ptr, size_t N)
{
#if defined(_OPENMP)
#pragma omp parallel for
#endif /* _OPENMP */
    for(unsigned int n = 0; n < N; ++n)
    {
        // Accumulate in double: this is a validation reference and should be as accurate as practical
        double xr = 0;
        double xi = 0;
        for(unsigned int k = 0; k < N; ++k)
        {
            // Exact integer reduction of k * n modulo N keeps the angle in [0, 2*pi) without precision loss
            const uint64_t phase     = (static_cast<uint64_t>(k) * n) % N;
            const double   alpha     = (2 * M_PI * phase) / N;
            const double   cos_alpha = std::cos(alpha);
            const double   sin_alpha = std::sin(alpha);
            const float    val_r     = src_ptr[2 * k];
            const float    val_i     = src_ptr[2 * k + 1];

            // (val_r + i*val_i) * (cos_alpha + i*sin_alpha)
            xr += val_r * cos_alpha - val_i * sin_alpha;
            xi += val_i * cos_alpha + val_r * sin_alpha;
        }

        dst_ptr[2 * n]     = static_cast<float>(xr);
        dst_ptr[2 * n + 1] = static_cast<float>(xi);
    }
}
177*c217d954SCole Faust
178*c217d954SCole Faust template <typename T>
rdft_1d_core(const SimpleTensor<T> & src,FFTDirection direction,bool is_odd)179*c217d954SCole Faust SimpleTensor<T> rdft_1d_core(const SimpleTensor<T> &src, FFTDirection direction, bool is_odd)
180*c217d954SCole Faust {
181*c217d954SCole Faust // Performs only rdft
182*c217d954SCole Faust ARM_COMPUTE_ERROR_ON(direction == FFTDirection::Forward && src.num_channels() != 1);
183*c217d954SCole Faust ARM_COMPUTE_ERROR_ON(direction == FFTDirection::Inverse && src.num_channels() != 2);
184*c217d954SCole Faust
185*c217d954SCole Faust const unsigned int inverse_tail = is_odd ? 1 : 0;
186*c217d954SCole Faust const unsigned int N = src.shape()[0];
187*c217d954SCole Faust const unsigned int K = direction == FFTDirection::Forward ? N / 2 + 1 : (N - 1) * 2 + inverse_tail;
188*c217d954SCole Faust const unsigned int num_channels = direction == FFTDirection::Forward ? 2 : 1;
189*c217d954SCole Faust
190*c217d954SCole Faust TensorShape dst_shape = src.shape();
191*c217d954SCole Faust dst_shape.set(0, K);
192*c217d954SCole Faust
193*c217d954SCole Faust SimpleTensor<T> dst(dst_shape, src.data_type(), num_channels);
194*c217d954SCole Faust
195*c217d954SCole Faust const unsigned int upper_dims = src.shape().total_size_upper(1);
196*c217d954SCole Faust #if defined(_OPENMP)
197*c217d954SCole Faust #pragma omp parallel for
198*c217d954SCole Faust #endif /* _OPENMP */
199*c217d954SCole Faust for(unsigned int du = 0; du < upper_dims; ++du)
200*c217d954SCole Faust {
201*c217d954SCole Faust const T *src_row_ptr = src.data() + du * N * src.num_channels();
202*c217d954SCole Faust T *dst_row_ptr = dst.data() + du * K * dst.num_channels();
203*c217d954SCole Faust direction == FFTDirection::Forward ? rdft_1d_step(src_row_ptr, N, dst_row_ptr, K) : irdft_1d_step(src_row_ptr, N, dst_row_ptr, K);
204*c217d954SCole Faust }
205*c217d954SCole Faust
206*c217d954SCole Faust return dst;
207*c217d954SCole Faust }
208*c217d954SCole Faust
209*c217d954SCole Faust template <typename T>
dft_1d_core(const SimpleTensor<T> & src,FFTDirection direction)210*c217d954SCole Faust SimpleTensor<T> dft_1d_core(const SimpleTensor<T> &src, FFTDirection direction)
211*c217d954SCole Faust {
212*c217d954SCole Faust ARM_COMPUTE_ERROR_ON(src.num_channels() != 2);
213*c217d954SCole Faust
214*c217d954SCole Faust const unsigned int N = src.shape()[0];
215*c217d954SCole Faust
216*c217d954SCole Faust SimpleTensor<T> dst(src.shape(), src.data_type(), src.num_channels());
217*c217d954SCole Faust
218*c217d954SCole Faust const unsigned int upper_dims = src.shape().total_size_upper(1);
219*c217d954SCole Faust #if defined(_OPENMP)
220*c217d954SCole Faust #pragma omp parallel for
221*c217d954SCole Faust #endif /* _OPENMP */
222*c217d954SCole Faust for(unsigned int du = 0; du < upper_dims; ++du)
223*c217d954SCole Faust {
224*c217d954SCole Faust const T *src_row_ptr = src.data() + du * N * src.num_channels();
225*c217d954SCole Faust T *dst_row_ptr = dst.data() + du * N * dst.num_channels();
226*c217d954SCole Faust direction == FFTDirection::Forward ? dft_1d_step(src_row_ptr, dst_row_ptr, N) : idft_1d_step(src_row_ptr, dst_row_ptr, N);
227*c217d954SCole Faust }
228*c217d954SCole Faust
229*c217d954SCole Faust return dst;
230*c217d954SCole Faust }
231*c217d954SCole Faust
232*c217d954SCole Faust /** Scale a tensor by a given scaling factor.
233*c217d954SCole Faust *
234*c217d954SCole Faust * @param[in,out] tensor Tensor to scale.
235*c217d954SCole Faust * @param[in] scaling_factor Scaling to scale the tensor data with.
236*c217d954SCole Faust */
237*c217d954SCole Faust template <typename T>
scale(SimpleTensor<T> & tensor,T scaling_factor)238*c217d954SCole Faust void scale(SimpleTensor<T> &tensor, T scaling_factor)
239*c217d954SCole Faust {
240*c217d954SCole Faust const int total_elements = tensor.num_elements() * tensor.num_channels();
241*c217d954SCole Faust T *data_ptr = tensor.data();
242*c217d954SCole Faust #if defined(_OPENMP)
243*c217d954SCole Faust #pragma omp parallel for
244*c217d954SCole Faust #endif /* _OPENMP */
245*c217d954SCole Faust for(int i = 0; i < total_elements; ++i)
246*c217d954SCole Faust {
247*c217d954SCole Faust data_ptr[i] /= scaling_factor;
248*c217d954SCole Faust }
249*c217d954SCole Faust }
250*c217d954SCole Faust
251*c217d954SCole Faust /** Performs a complex element-wise multiplication with reduction across the channels axis.
252*c217d954SCole Faust *
253*c217d954SCole Faust * @param[in] input Input tensor.
254*c217d954SCole Faust * @param[in] weights Weights tensor.
255*c217d954SCole Faust *
256*c217d954SCole Faust * @return Output tensor.
257*c217d954SCole Faust */
258*c217d954SCole Faust template <typename T>
complex_mul_and_reduce(const SimpleTensor<T> & input,const SimpleTensor<T> & weights)259*c217d954SCole Faust SimpleTensor<T> complex_mul_and_reduce(const SimpleTensor<T> &input, const SimpleTensor<T> &weights)
260*c217d954SCole Faust {
261*c217d954SCole Faust const uint32_t W = input.shape().x();
262*c217d954SCole Faust const uint32_t H = input.shape().y();
263*c217d954SCole Faust const uint32_t Ci = input.shape().z();
264*c217d954SCole Faust const uint32_t Co = weights.shape()[3];
265*c217d954SCole Faust const uint32_t N = input.shape().total_size() / (W * H * Ci);
266*c217d954SCole Faust
267*c217d954SCole Faust TensorShape output_shape = input.shape();
268*c217d954SCole Faust output_shape.set(2, Co);
269*c217d954SCole Faust SimpleTensor<T> dst(output_shape, input.data_type(), input.num_channels());
270*c217d954SCole Faust
271*c217d954SCole Faust // dst memory to zero
272*c217d954SCole Faust const auto total_element_count = dst.num_channels() * dst.num_elements();
273*c217d954SCole Faust std::fill_n(dst.data(), total_element_count, 0);
274*c217d954SCole Faust
275*c217d954SCole Faust for(uint32_t b = 0; b < N; ++b)
276*c217d954SCole Faust {
277*c217d954SCole Faust for(uint32_t co = 0; co < Co; ++co)
278*c217d954SCole Faust {
279*c217d954SCole Faust for(uint32_t ci = 0; ci < Ci; ++ci)
280*c217d954SCole Faust {
281*c217d954SCole Faust for(uint32_t h = 0; h < H; ++h)
282*c217d954SCole Faust {
283*c217d954SCole Faust for(uint32_t w = 0; w < W; ++w)
284*c217d954SCole Faust {
285*c217d954SCole Faust const uint32_t i_index = w + h * W + ci * H * W + b * H * W * Ci;
286*c217d954SCole Faust const uint32_t w_index = w + h * W + ci * H * W + co * H * W * Ci;
287*c217d954SCole Faust const uint32_t o_index = w + h * W + co * H * W + b * H * W * Co;
288*c217d954SCole Faust const Coordinates i_coords = index2coords(input.shape(), i_index);
289*c217d954SCole Faust const Coordinates w_coords = index2coords(weights.shape(), w_index);
290*c217d954SCole Faust const Coordinates o_coords = index2coords(dst.shape(), o_index);
291*c217d954SCole Faust
292*c217d954SCole Faust auto i_ptr = static_cast<const T *>(input(i_coords));
293*c217d954SCole Faust auto w_ptr = static_cast<const T *>(weights(w_coords));
294*c217d954SCole Faust auto o_ptr = static_cast<T *>(dst(o_coords));
295*c217d954SCole Faust
296*c217d954SCole Faust const T Rin = i_ptr[0];
297*c217d954SCole Faust const T Iin = i_ptr[1];
298*c217d954SCole Faust const T Rw = w_ptr[0];
299*c217d954SCole Faust const T Iw = w_ptr[1];
300*c217d954SCole Faust
301*c217d954SCole Faust o_ptr[0] += Rin * Rw - Iin * Iw;
302*c217d954SCole Faust o_ptr[1] += Rin * Iw + Rw * Iin;
303*c217d954SCole Faust }
304*c217d954SCole Faust }
305*c217d954SCole Faust }
306*c217d954SCole Faust }
307*c217d954SCole Faust }
308*c217d954SCole Faust return dst;
309*c217d954SCole Faust }
310*c217d954SCole Faust } // namespace
311*c217d954SCole Faust
312*c217d954SCole Faust template <typename T>
rdft_1d(const SimpleTensor<T> & src)313*c217d954SCole Faust SimpleTensor<T> rdft_1d(const SimpleTensor<T> &src)
314*c217d954SCole Faust {
315*c217d954SCole Faust return rdft_1d_core(src, FFTDirection::Forward, false);
316*c217d954SCole Faust }
317*c217d954SCole Faust
318*c217d954SCole Faust template <typename T>
ridft_1d(const SimpleTensor<T> & src,bool is_odd)319*c217d954SCole Faust SimpleTensor<T> ridft_1d(const SimpleTensor<T> &src, bool is_odd)
320*c217d954SCole Faust {
321*c217d954SCole Faust auto dst = rdft_1d_core(src, FFTDirection::Inverse, is_odd);
322*c217d954SCole Faust
323*c217d954SCole Faust const T scaling_factor = T(dst.shape()[0]);
324*c217d954SCole Faust scale(dst, scaling_factor);
325*c217d954SCole Faust
326*c217d954SCole Faust return dst;
327*c217d954SCole Faust }
328*c217d954SCole Faust
329*c217d954SCole Faust template <typename T>
dft_1d(const SimpleTensor<T> & src,FFTDirection direction)330*c217d954SCole Faust SimpleTensor<T> dft_1d(const SimpleTensor<T> &src, FFTDirection direction)
331*c217d954SCole Faust {
332*c217d954SCole Faust auto dst = dft_1d_core(src, direction);
333*c217d954SCole Faust if(direction == FFTDirection::Inverse)
334*c217d954SCole Faust {
335*c217d954SCole Faust const T scaling_factor = T(dst.shape()[0]);
336*c217d954SCole Faust scale(dst, scaling_factor);
337*c217d954SCole Faust }
338*c217d954SCole Faust return dst;
339*c217d954SCole Faust }
340*c217d954SCole Faust
341*c217d954SCole Faust template <typename T>
rdft_2d(const SimpleTensor<T> & src)342*c217d954SCole Faust SimpleTensor<T> rdft_2d(const SimpleTensor<T> &src)
343*c217d954SCole Faust {
344*c217d954SCole Faust ARM_COMPUTE_ERROR_ON(src.num_channels() != 1);
345*c217d954SCole Faust constexpr FFTDirection direction = FFTDirection::Forward;
346*c217d954SCole Faust
347*c217d954SCole Faust auto first_pass = rdft_1d_core(src, direction, false);
348*c217d954SCole Faust auto transposed = permute(first_pass, PermutationVector(1U, 0U));
349*c217d954SCole Faust auto second_pass = dft_1d_core(transposed, direction);
350*c217d954SCole Faust return permute(second_pass, PermutationVector(1U, 0U));
351*c217d954SCole Faust }
352*c217d954SCole Faust
353*c217d954SCole Faust template <typename T>
ridft_2d(const SimpleTensor<T> & src,bool is_odd)354*c217d954SCole Faust SimpleTensor<T> ridft_2d(const SimpleTensor<T> &src, bool is_odd)
355*c217d954SCole Faust {
356*c217d954SCole Faust ARM_COMPUTE_ERROR_ON(src.num_channels() != 2);
357*c217d954SCole Faust constexpr FFTDirection direction = FFTDirection::Inverse;
358*c217d954SCole Faust
359*c217d954SCole Faust auto transposed = permute(src, PermutationVector(1U, 0U));
360*c217d954SCole Faust auto first_pass = dft_1d_core(transposed, direction);
361*c217d954SCole Faust auto transposed_2 = permute(first_pass, PermutationVector(1U, 0U));
362*c217d954SCole Faust auto dst = rdft_1d_core(transposed_2, direction, is_odd);
363*c217d954SCole Faust
364*c217d954SCole Faust const T scaling_factor = T(dst.shape()[0] * dst.shape()[1]);
365*c217d954SCole Faust scale(dst, scaling_factor);
366*c217d954SCole Faust return dst;
367*c217d954SCole Faust }
368*c217d954SCole Faust
369*c217d954SCole Faust template <typename T>
dft_2d(const SimpleTensor<T> & src,FFTDirection direction)370*c217d954SCole Faust SimpleTensor<T> dft_2d(const SimpleTensor<T> &src, FFTDirection direction)
371*c217d954SCole Faust {
372*c217d954SCole Faust ARM_COMPUTE_ERROR_ON(src.num_channels() != 2);
373*c217d954SCole Faust
374*c217d954SCole Faust if(direction == FFTDirection::Forward)
375*c217d954SCole Faust {
376*c217d954SCole Faust auto first_pass = dft_1d_core(src, direction);
377*c217d954SCole Faust auto transposed = permute(first_pass, PermutationVector(1U, 0U));
378*c217d954SCole Faust auto second_pass = dft_1d_core(transposed, direction);
379*c217d954SCole Faust return permute(second_pass, PermutationVector(1U, 0U));
380*c217d954SCole Faust }
381*c217d954SCole Faust else
382*c217d954SCole Faust {
383*c217d954SCole Faust auto transposed = permute(src, PermutationVector(1U, 0U));
384*c217d954SCole Faust auto first_pass = dft_1d_core(transposed, direction);
385*c217d954SCole Faust auto transposed_2 = permute(first_pass, PermutationVector(1U, 0U));
386*c217d954SCole Faust auto dst = dft_1d_core(transposed_2, direction);
387*c217d954SCole Faust
388*c217d954SCole Faust const T scaling_factor = T(dst.shape()[0] * dst.shape()[1]);
389*c217d954SCole Faust scale(dst, scaling_factor);
390*c217d954SCole Faust
391*c217d954SCole Faust return dst;
392*c217d954SCole Faust }
393*c217d954SCole Faust }
394*c217d954SCole Faust
395*c217d954SCole Faust template <typename T>
conv2d_dft(const SimpleTensor<T> & src,const SimpleTensor<T> & w,const PadStrideInfo & conv_info)396*c217d954SCole Faust SimpleTensor<T> conv2d_dft(const SimpleTensor<T> &src, const SimpleTensor<T> &w, const PadStrideInfo &conv_info)
397*c217d954SCole Faust {
398*c217d954SCole Faust // Pad input to full padding
399*c217d954SCole Faust const PaddingList padding_in = { { 0, w.shape()[0] - 1 }, { 0, w.shape()[1] - 1 } };
400*c217d954SCole Faust auto padded_src = pad_layer(src, padding_in);
401*c217d954SCole Faust
402*c217d954SCole Faust // Flip weights
403*c217d954SCole Faust std::vector<uint32_t> axis_v = { 0, 1 };
404*c217d954SCole Faust SimpleTensor<uint32_t> axis{ TensorShape(2U), DataType::U32 };
405*c217d954SCole Faust std::copy(axis_v.begin(), axis_v.begin() + axis.shape().x(), axis.data());
406*c217d954SCole Faust auto flipped_w = reverse(w, axis);
407*c217d954SCole Faust
408*c217d954SCole Faust // Pad weights to have the same size as input
409*c217d954SCole Faust const PaddingList paddings_w = { { 0, src.shape()[0] - 1 }, { 0, src.shape()[1] - 1 } };
410*c217d954SCole Faust auto padded_w = pad_layer(flipped_w, paddings_w);
411*c217d954SCole Faust
412*c217d954SCole Faust // Transform input and weights to frequency domain
413*c217d954SCole Faust auto Fsrc = rdft_2d(padded_src);
414*c217d954SCole Faust auto Fw = rdft_2d(padded_w);
415*c217d954SCole Faust
416*c217d954SCole Faust // Perform dot product
417*c217d954SCole Faust auto Fdst = complex_mul_and_reduce(Fsrc, Fw);
418*c217d954SCole Faust
419*c217d954SCole Faust // Transform output back to frequency domain
420*c217d954SCole Faust auto conv_res = ridft_2d(Fdst);
421*c217d954SCole Faust
422*c217d954SCole Faust // Slice output
423*c217d954SCole Faust const int start_left = w.shape().x() - conv_info.pad_left() - 1;
424*c217d954SCole Faust const int start_top = w.shape().y() - conv_info.pad_top() - 1;
425*c217d954SCole Faust const int end_right = conv_res.shape().x() - (w.shape().x() - conv_info.pad_right() - 1);
426*c217d954SCole Faust const int end_botton = conv_res.shape().y() - (w.shape().y() - conv_info.pad_bottom() - 1);
427*c217d954SCole Faust return slice(conv_res, Coordinates(start_left, start_top), Coordinates(end_right, end_botton));
428*c217d954SCole Faust }
429*c217d954SCole Faust
430*c217d954SCole Faust // FP32
431*c217d954SCole Faust template SimpleTensor<float> rdft_1d(const SimpleTensor<float> &src);
432*c217d954SCole Faust template SimpleTensor<float> ridft_1d(const SimpleTensor<float> &src, bool is_odd);
433*c217d954SCole Faust template SimpleTensor<float> dft_1d(const SimpleTensor<float> &src, FFTDirection direction);
434*c217d954SCole Faust
435*c217d954SCole Faust template SimpleTensor<float> rdft_2d(const SimpleTensor<float> &src);
436*c217d954SCole Faust template SimpleTensor<float> ridft_2d(const SimpleTensor<float> &src, bool is_odd);
437*c217d954SCole Faust template SimpleTensor<float> dft_2d(const SimpleTensor<float> &src, FFTDirection direction);
438*c217d954SCole Faust
439*c217d954SCole Faust template SimpleTensor<float> conv2d_dft(const SimpleTensor<float> &src, const SimpleTensor<float> &w, const PadStrideInfo &conv_info);
440*c217d954SCole Faust
441*c217d954SCole Faust // FP16
442*c217d954SCole Faust template SimpleTensor<half> rdft_1d(const SimpleTensor<half> &src);
443*c217d954SCole Faust template SimpleTensor<half> ridft_1d(const SimpleTensor<half> &src, bool is_odd);
444*c217d954SCole Faust template SimpleTensor<half> dft_1d(const SimpleTensor<half> &src, FFTDirection direction);
445*c217d954SCole Faust
446*c217d954SCole Faust template SimpleTensor<half> rdft_2d(const SimpleTensor<half> &src);
447*c217d954SCole Faust template SimpleTensor<half> ridft_2d(const SimpleTensor<half> &src, bool is_odd);
448*c217d954SCole Faust template SimpleTensor<half> dft_2d(const SimpleTensor<half> &src, FFTDirection direction);
449*c217d954SCole Faust
450*c217d954SCole Faust template SimpleTensor<half> conv2d_dft(const SimpleTensor<half> &src, const SimpleTensor<half> &w, const PadStrideInfo &conv_info);
451*c217d954SCole Faust } // namespace reference
452*c217d954SCole Faust } // namespace validation
453*c217d954SCole Faust } // namespace test
454*c217d954SCole Faust } // namespace arm_compute
455