xref: /aosp_15_r20/external/ComputeLibrary/src/core/CL/cl_kernels/common/fft.cl (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1/*
2 * Copyright (c) 2019-2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
26#if defined(DATA_TYPE)
/** Multiplies a complex value, in place, by the twiddle factor w = cos(phi) + i*sin(phi).
 *
 * Complex values are 2-component DATA_TYPE vectors holding the real part in .x and the
 * imaginary part in .y. The factor is built in DATA_TYPE precision and the complex product
 * w * input is written back to @p input.
 *
 * Temporaries carry a "twiddle_" prefix so that they cannot shadow a caller argument with
 * the same name, and @p input is parenthesised on every use.
 *
 * @param[in]     phi   The angle, in radians. Evaluated twice (once by cos, once by sin).
 * @param[in,out] input The complex value to rotate; must be an assignable lvalue.
 */
#define TWIDDLE_FACTOR_MULTIPLICATION(phi, input)                              \
    {                                                                          \
        VEC_DATA_TYPE(DATA_TYPE, 2)                                            \
        twiddle_w, twiddle_res;                                                \
        twiddle_w.x   = cos(phi);                                              \
        twiddle_w.y   = sin(phi);                                              \
        twiddle_res.x = (twiddle_w.x * (input).x) - (twiddle_w.y * (input).y); \
        twiddle_res.y = (twiddle_w.x * (input).y) + (twiddle_w.y * (input).x); \
        (input)       = twiddle_res;                                           \
    }
42
/** Computes a radix-2 butterfly (2-point DFT) in place.
 *
 * On return c0 holds c0 + c1 and c1 holds c0 - c1, both computed from the input values.
 * The temporary is prefixed with "dft2_" so it cannot shadow a caller argument.
 *
 * @param[in,out] c0 Complex input 0 (2-component vector, .x real, .y imaginary); receives output 0.
 * @param[in,out] c1 Complex input 1 (2-component vector, .x real, .y imaginary); receives output 1.
 */
#define DFT_2(c0, c1)               \
    {                               \
        VEC_DATA_TYPE(DATA_TYPE, 2) \
        dft2_v0 = (c0);             \
        (c0)    = dft2_v0 + (c1);   \
        (c1)    = dft2_v0 - (c1);   \
    }
56
// radix-3 butterfly unit factor: sqrt(3) / 2 == sin(2 * pi / 3)
// NOTE(review): unlike the W5_* / W7_* factors this literal is not cast to DATA_TYPE, so for a
// half DATA_TYPE the products in DFT_3 are presumably evaluated in float and rounded on assignment.
#define SQRT3DIV2 0.86602540378443f

/** Computes a radix-3 butterfly (3-point forward DFT) in place.
 *
 * With W = exp(-2*pi*i/3), on return c_k = sum_{n=0..2} c_n * W^(n*k) for k = 0, 1, 2.
 * Temporaries are prefixed with "dft3_" so they cannot shadow a caller argument.
 *
 * @param[in,out] c0 Complex input 0 (.x real, .y imaginary); receives output 0.
 * @param[in,out] c1 Complex input 1 (.x real, .y imaginary); receives output 1.
 * @param[in,out] c2 Complex input 2 (.x real, .y imaginary); receives output 2.
 */
#define DFT_3(c0, c1, c2)                                             \
    {                                                                 \
        VEC_DATA_TYPE(DATA_TYPE, 2)                                   \
        dft3_v0 = (c1) + (c2);                                        \
        VEC_DATA_TYPE(DATA_TYPE, 2)                                   \
        dft3_v1 = (c1) - (c2);                                        \
        (c1).x  = (c0).x - 0.5f * dft3_v0.x + dft3_v1.y * SQRT3DIV2;  \
        (c1).y  = (c0).y - 0.5f * dft3_v0.y - dft3_v1.x * SQRT3DIV2;  \
        (c2).x  = (c0).x - 0.5f * dft3_v0.x - dft3_v1.y * SQRT3DIV2;  \
        (c2).y  = (c0).y - 0.5f * dft3_v0.y + dft3_v1.x * SQRT3DIV2;  \
        (c0)    = (c0) + dft3_v0;                                     \
    }
78
/** Computes a radix-4 butterfly (4-point forward DFT) in place.
 *
 * With W = exp(-2*pi*i/4) = -i, on return c_k = sum_{n=0..3} c_n * W^(n*k) for k = 0..3.
 * Temporaries are prefixed with "dft4_" so they cannot shadow a caller argument.
 *
 * @param[in,out] c0 Complex input 0 (.x real, .y imaginary); receives output 0.
 * @param[in,out] c1 Complex input 1 (.x real, .y imaginary); receives output 1.
 * @param[in,out] c2 Complex input 2 (.x real, .y imaginary); receives output 2.
 * @param[in,out] c3 Complex input 3 (.x real, .y imaginary); receives output 3.
 */
#define DFT_4(c0, c1, c2, c3)                   \
    {                                           \
        VEC_DATA_TYPE(DATA_TYPE, 2)             \
        dft4_v0, dft4_v1, dft4_v2, dft4_v3;     \
        dft4_v0   = (c0) + (c2);                \
        dft4_v1   = (c1) + (c3);                \
        dft4_v2   = (c0) - (c2);                \
        /* dft4_v3 = -i * (c1 - c3) */          \
        dft4_v3.x = (c1).y - (c3).y;            \
        dft4_v3.y = (c3).x - (c1).x;            \
        (c0)      = dft4_v0 + dft4_v1;          \
        (c2)      = dft4_v0 - dft4_v1;          \
        (c1)      = dft4_v2 + dft4_v3;          \
        (c3)      = dft4_v2 - dft4_v3;          \
    }
100
// radix-5 butterfly unit factors
#define W5_A ((DATA_TYPE)0.30901699437494f) // cos(2 * pi / 5)
#define W5_B ((DATA_TYPE)0.95105651629515f) // sin(2 * pi / 5)
#define W5_C ((DATA_TYPE)0.80901699437494f) // -cos(4 * pi / 5)
#define W5_D ((DATA_TYPE)0.58778525229247f) // sin(4 * pi / 5)

/** Computes a radix-5 butterfly (5-point forward DFT) in place.
 *
 * With W = exp(-2*pi*i/5), on return c_k = sum_{n=0..4} c_n * W^(n*k) for k = 0..4.
 * dft5_v1/dft5_v2 hold the real-axis (cosine) terms and dft5_v3/dft5_v4 the sine terms; the
 * sine terms are rotated by -i or +i (swap components and negate one) to form outputs 1..4.
 * Temporaries are prefixed with "dft5_" so they cannot shadow a caller argument.
 *
 * @param[in,out] c0 Complex input 0 (.x real, .y imaginary); receives output 0.
 * @param[in,out] c1 Complex input 1 (.x real, .y imaginary); receives output 1.
 * @param[in,out] c2 Complex input 2 (.x real, .y imaginary); receives output 2.
 * @param[in,out] c3 Complex input 3 (.x real, .y imaginary); receives output 3.
 * @param[in,out] c4 Complex input 4 (.x real, .y imaginary); receives output 4.
 */
#define DFT_5(c0, c1, c2, c3, c4)                                                              \
    {                                                                                          \
        VEC_DATA_TYPE(DATA_TYPE, 2)                                                            \
        dft5_v0, dft5_v1, dft5_v2, dft5_v3, dft5_v4;                                           \
        dft5_v0 = (c0);                                                                        \
        dft5_v1 = W5_A * ((c1) + (c4)) - W5_C * ((c2) + (c3));                                 \
        dft5_v2 = W5_C * ((c1) + (c4)) - W5_A * ((c2) + (c3));                                 \
        dft5_v3 = W5_D * ((c1) - (c4)) - W5_B * ((c2) - (c3));                                 \
        dft5_v4 = W5_B * ((c1) - (c4)) + W5_D * ((c2) - (c3));                                 \
        (c0)    = dft5_v0 + (c1) + (c2) + (c3) + (c4);                                         \
        (c1)    = dft5_v0 + dft5_v1 + (VEC_DATA_TYPE(DATA_TYPE, 2))(dft5_v4.y, -dft5_v4.x);    \
        (c2)    = dft5_v0 - dft5_v2 + (VEC_DATA_TYPE(DATA_TYPE, 2))(dft5_v3.y, -dft5_v3.x);    \
        (c3)    = dft5_v0 - dft5_v2 + (VEC_DATA_TYPE(DATA_TYPE, 2))(-dft5_v3.y, dft5_v3.x);    \
        (c4)    = dft5_v0 + dft5_v1 + (VEC_DATA_TYPE(DATA_TYPE, 2))(-dft5_v4.y, dft5_v4.x);    \
    }
130
// radix-7 butterfly unit factors
#define W7_A ((DATA_TYPE)0.62348980185873f) // cos(2 * pi / 7)
#define W7_B ((DATA_TYPE)0.78183148246802f) // sin(2 * pi / 7)
#define W7_C ((DATA_TYPE)0.22252093395631f) // -cos(4 * pi / 7)
#define W7_D ((DATA_TYPE)0.97492791218182f) // sin(4 * pi / 7)
#define W7_E ((DATA_TYPE)0.90096886790241f) // -cos(6 * pi / 7)
#define W7_F ((DATA_TYPE)0.43388373911755f) // sin(6 * pi / 7)

/** Computes a radix-7 butterfly (7-point forward DFT) in place.
 *
 * With W = exp(-2*pi*i/7), on return c_k = sum_{n=0..6} c_n * W^(n*k) for k = 0..6.
 * dft7_v1..dft7_v3 hold the cosine terms and dft7_v4..dft7_v6 the sine terms; outputs k and
 * 7 - k share the same terms, with the sine part rotated by -i and +i respectively.
 * Temporaries are prefixed with "dft7_" so they cannot shadow a caller argument.
 *
 * @param[in,out] c0 Complex input 0 (.x real, .y imaginary); receives output 0.
 * @param[in,out] c1 Complex input 1 (.x real, .y imaginary); receives output 1.
 * @param[in,out] c2 Complex input 2 (.x real, .y imaginary); receives output 2.
 * @param[in,out] c3 Complex input 3 (.x real, .y imaginary); receives output 3.
 * @param[in,out] c4 Complex input 4 (.x real, .y imaginary); receives output 4.
 * @param[in,out] c5 Complex input 5 (.x real, .y imaginary); receives output 5.
 * @param[in,out] c6 Complex input 6 (.x real, .y imaginary); receives output 6.
 */
#define DFT_7(c0, c1, c2, c3, c4, c5, c6)                                                       \
    {                                                                                           \
        VEC_DATA_TYPE(DATA_TYPE, 2)                                                             \
        dft7_v0, dft7_v1, dft7_v2, dft7_v3, dft7_v4, dft7_v5, dft7_v6;                          \
        dft7_v0 = (c0);                                                                         \
        dft7_v1 = W7_A * ((c1) + (c6)) - W7_C * ((c2) + (c5)) - W7_E * ((c3) + (c4));           \
        dft7_v2 = W7_C * ((c1) + (c6)) + W7_E * ((c2) + (c5)) - W7_A * ((c3) + (c4));           \
        dft7_v3 = W7_E * ((c1) + (c6)) - W7_A * ((c2) + (c5)) + W7_C * ((c3) + (c4));           \
        dft7_v4 = W7_B * ((c1) - (c6)) + W7_D * ((c2) - (c5)) + W7_F * ((c3) - (c4));           \
        dft7_v5 = W7_D * ((c1) - (c6)) - W7_F * ((c2) - (c5)) - W7_B * ((c3) - (c4));           \
        dft7_v6 = W7_F * ((c1) - (c6)) - W7_B * ((c2) - (c5)) + W7_D * ((c3) - (c4));           \
        (c0)    = dft7_v0 + (c1) + (c2) + (c3) + (c4) + (c5) + (c6);                            \
        (c1)    = dft7_v0 + dft7_v1 + (VEC_DATA_TYPE(DATA_TYPE, 2))(dft7_v4.y, -dft7_v4.x);     \
        (c2)    = dft7_v0 - dft7_v2 + (VEC_DATA_TYPE(DATA_TYPE, 2))(dft7_v5.y, -dft7_v5.x);     \
        (c3)    = dft7_v0 - dft7_v3 + (VEC_DATA_TYPE(DATA_TYPE, 2))(dft7_v6.y, -dft7_v6.x);     \
        (c4)    = dft7_v0 - dft7_v3 + (VEC_DATA_TYPE(DATA_TYPE, 2))(-dft7_v6.y, dft7_v6.x);     \
        (c5)    = dft7_v0 - dft7_v2 + (VEC_DATA_TYPE(DATA_TYPE, 2))(-dft7_v5.y, dft7_v5.x);     \
        (c6)    = dft7_v0 + dft7_v1 + (VEC_DATA_TYPE(DATA_TYPE, 2))(-dft7_v4.y, dft7_v4.x);     \
    }
168
/** Computes a radix-8 butterfly (8-point forward DFT) in place.
 *
 * With W = exp(-2*pi*i/8), on return c_k = sum_{n=0..7} c_n * W^(n*k) for k = 0..7.
 * Temporaries are prefixed with "dft8_" so they cannot shadow a caller argument.
 *
 * @param[in,out] c0 Complex input 0 (.x real, .y imaginary); receives output 0.
 * @param[in,out] c1 Complex input 1 (.x real, .y imaginary); receives output 1.
 * @param[in,out] c2 Complex input 2 (.x real, .y imaginary); receives output 2.
 * @param[in,out] c3 Complex input 3 (.x real, .y imaginary); receives output 3.
 * @param[in,out] c4 Complex input 4 (.x real, .y imaginary); receives output 4.
 * @param[in,out] c5 Complex input 5 (.x real, .y imaginary); receives output 5.
 * @param[in,out] c6 Complex input 6 (.x real, .y imaginary); receives output 6.
 * @param[in,out] c7 Complex input 7 (.x real, .y imaginary); receives output 7.
 */
#define DFT_8(c0, c1, c2, c3, c4, c5, c6, c7)                                       \
    {                                                                               \
        VEC_DATA_TYPE(DATA_TYPE, 2)                                                 \
        dft8_v0, dft8_v1, dft8_v2, dft8_v3, dft8_v4, dft8_v5, dft8_v6, dft8_v7;     \
        VEC_DATA_TYPE(DATA_TYPE, 2)                                                 \
        dft8_s0, dft8_s1, dft8_s2, dft8_s3, dft8_s4, dft8_s5, dft8_s6, dft8_s7;     \
        VEC_DATA_TYPE(DATA_TYPE, 2)                                                 \
        dft8_t0, dft8_t1, dft8_t2;                                                  \
        /* First level: butterflies between inputs k and k + 4 */                   \
        dft8_v0   = (c0) + (c4);                                                    \
        dft8_v1   = (c1) + (c5);                                                    \
        dft8_v2   = (c2) + (c6);                                                    \
        dft8_v3   = (c3) + (c7);                                                    \
        dft8_v4   = (c0) - (c4);                                                    \
        dft8_v5   = (c1) - (c5);                                                    \
        dft8_v6   = (c2) - (c6);                                                    \
        dft8_v7   = (c3) - (c7);                                                    \
        /* Second level: s4/s5 = v4/v5 + i*v6/v7, s6/s7 = v4/v5 - i*v6/v7 */        \
        dft8_s0   = dft8_v0 + dft8_v2;                                              \
        dft8_s1   = dft8_v1 + dft8_v3;                                              \
        dft8_s2   = dft8_v0 - dft8_v2;                                              \
        dft8_s3   = dft8_v1 - dft8_v3;                                              \
        dft8_s4.x = dft8_v4.x - dft8_v6.y;                                          \
        dft8_s4.y = dft8_v4.y + dft8_v6.x;                                          \
        dft8_s5.x = dft8_v5.x - dft8_v7.y;                                          \
        dft8_s5.y = dft8_v5.y + dft8_v7.x;                                          \
        dft8_s6.x = dft8_v4.x + dft8_v6.y;                                          \
        dft8_s6.y = dft8_v4.y - dft8_v6.x;                                          \
        dft8_s7.x = dft8_v5.x + dft8_v7.y;                                          \
        dft8_s7.y = dft8_v5.y - dft8_v7.x;                                          \
        /* Twiddles (W = exp(-i*pi/4)): t0 = i*s3, t1 = -W^3*s5, t2 = -W*s7 */      \
        dft8_t0.x = -dft8_s3.y;                                                     \
        dft8_t0.y = dft8_s3.x;                                                      \
        dft8_t1.x = M_SQRT1_2_F * (dft8_s5.x - dft8_s5.y);                          \
        dft8_t1.y = M_SQRT1_2_F * (dft8_s5.x + dft8_s5.y);                          \
        dft8_t2.x = -M_SQRT1_2_F * (dft8_s7.x + dft8_s7.y);                         \
        dft8_t2.y = M_SQRT1_2_F * (dft8_s7.x - dft8_s7.y);                          \
        (c0)      = dft8_s0 + dft8_s1;                                              \
        (c1)      = dft8_s6 - dft8_t2;                                              \
        (c2)      = dft8_s2 - dft8_t0;                                              \
        (c3)      = dft8_s4 - dft8_t1;                                              \
        (c4)      = dft8_s0 - dft8_s1;                                              \
        (c5)      = dft8_s6 + dft8_t2;                                              \
        (c6)      = dft8_s2 + dft8_t0;                                              \
        (c7)      = dft8_s4 + dft8_t1;                                              \
    }
223
/** Computes the first stage of a radix-2 DFT along axis 0.
 *
 * Each work-item loads two complex values (interleaved real/imaginary DATA_TYPE pairs)
 * stored contiguously at its input position, applies a radix-2 butterfly to them and
 * writes both results contiguously at its output position.
 *
 * @note The element type is selected at compile time with -DDATA_TYPE (e.g. -DDATA_TYPE=float).
 * @note Pass -DIN_PLACE at compile time to run in place: the output arguments are then
 *       omitted and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 */
__kernel void fft_radix_2_first_stage_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    // Resolve this work-item's source and destination positions
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D dst = src; // results overwrite the input
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    __global DATA_TYPE *src_addr = (__global DATA_TYPE *)src.ptr;
    __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)dst.ptr;

    // Two complex values packed as (re0, im0, re1, im1)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    pair = vload4(0, src_addr);

    DFT_2(pair.s01, pair.s23);

    vstore4(pair, 0, dst_addr);
}
271
/** Computes the first stage of a radix-2 DFT along axis 1.
 *
 * Each work-item loads the complex values (interleaved real/imaginary DATA_TYPE pairs)
 * found at Y offsets 0 and 1 from its input position, applies a radix-2 butterfly to them
 * and writes the results to the same Y offsets from its output position.
 *
 * @note The element type is selected at compile time with -DDATA_TYPE (e.g. -DDATA_TYPE=float).
 * @note Pass -DIN_PLACE at compile time to run in place: the output arguments are then
 *       omitted and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 */
__kernel void fft_radix_2_first_stage_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    // Resolve this work-item's source and destination positions
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D dst = src; // results overwrite the input
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // Gather the complex values at Y offsets 0 and 1 (all loads happen before any store)
    VEC_DATA_TYPE(DATA_TYPE, 2)
    vals[2];
    for(int row = 0; row < 2; ++row)
    {
        vals[row] = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, row, 0));
    }

    DFT_2(vals[0], vals[1]);

    // Scatter the results back to the same Y offsets of the destination
    for(int row = 0; row < 2; ++row)
    {
        vstore2(vals[row], 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, row, 0));
    }
}
322
/** Computes the first stage of a radix-3 DFT along axis 0.
 *
 * Each work-item loads three complex values (interleaved real/imaginary DATA_TYPE pairs):
 * values 0 and 1 with a single 4-wide vector load at its input position and value 2 at
 * X offset 2. It applies a radix-3 butterfly and writes the results back with the same
 * layout at its output position.
 *
 * @note The element type is selected at compile time with -DDATA_TYPE (e.g. -DDATA_TYPE=float).
 * @note Pass -DIN_PLACE at compile time to run in place: the output arguments are then
 *       omitted and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 */
__kernel void fft_radix_3_first_stage_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    // Resolve this work-item's source and destination positions
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D dst = src; // results overwrite the input
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // lo = (re0, im0, re1, im1), hi = (re2, im2)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    lo = vload4(0, (__global DATA_TYPE *)src.ptr);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    hi = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 2, 0, 0));

    DFT_3(lo.s01, lo.s23, hi);

    vstore4(lo, 0, (__global DATA_TYPE *)dst.ptr);
    vstore2(hi, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 2, 0, 0));
}
373
/** Computes the first stage of a radix-3 DFT along axis 1.
 *
 * Each work-item loads the complex values (interleaved real/imaginary DATA_TYPE pairs)
 * found at Y offsets 0, 1 and 2 from its input position, applies a radix-3 butterfly to
 * them and writes the results to the same Y offsets from its output position.
 *
 * @note The element type is selected at compile time with -DDATA_TYPE (e.g. -DDATA_TYPE=float).
 * @note Pass -DIN_PLACE at compile time to run in place: the output arguments are then
 *       omitted and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 */
__kernel void fft_radix_3_first_stage_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    // Resolve this work-item's source and destination positions
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D dst = src; // results overwrite the input
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // Gather the complex values at Y offsets 0..2 (all loads happen before any store)
    VEC_DATA_TYPE(DATA_TYPE, 2)
    vals[3];
    for(int row = 0; row < 3; ++row)
    {
        vals[row] = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, row, 0));
    }

    DFT_3(vals[0], vals[1], vals[2]);

    // Scatter the results back to the same Y offsets of the destination
    for(int row = 0; row < 3; ++row)
    {
        vstore2(vals[row], 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, row, 0));
    }
}
427
/** Computes the first stage of a radix-4 DFT along axis 0.
 *
 * Each work-item loads four complex values (interleaved real/imaginary DATA_TYPE pairs)
 * stored contiguously at its input position, applies a radix-4 butterfly to them and
 * writes the four results contiguously at its output position.
 *
 * @note The element type is selected at compile time with -DDATA_TYPE (e.g. -DDATA_TYPE=float).
 * @note Pass -DIN_PLACE at compile time to run in place: the output arguments are then
 *       omitted and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 */
__kernel void fft_radix_4_first_stage_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    // Resolve this work-item's source and destination positions
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D dst = src; // results overwrite the input
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    __global DATA_TYPE *src_addr = (__global DATA_TYPE *)src.ptr;
    __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)dst.ptr;

    // Four complex values packed as (re0, im0, re1, im1, re2, im2, re3, im3)
    VEC_DATA_TYPE(DATA_TYPE, 8)
    quad = vload8(0, src_addr);

    DFT_4(quad.s01, quad.s23, quad.s45, quad.s67);

    vstore8(quad, 0, dst_addr);
}
475
/** Computes the first stage of a radix-4 DFT along axis 1.
 *
 * Each work-item loads the complex values (interleaved real/imaginary DATA_TYPE pairs)
 * found at Y offsets 0..3 from its input position, applies a radix-4 butterfly to them
 * and writes the results to the same Y offsets from its output position.
 *
 * @note The element type is selected at compile time with -DDATA_TYPE (e.g. -DDATA_TYPE=float).
 * @note Pass -DIN_PLACE at compile time to run in place: the output arguments are then
 *       omitted and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 */
__kernel void fft_radix_4_first_stage_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    // Resolve this work-item's source and destination positions
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D dst = src; // results overwrite the input
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // Gather the complex values at Y offsets 0..3 (all loads happen before any store)
    VEC_DATA_TYPE(DATA_TYPE, 2)
    vals[4];
    for(int row = 0; row < 4; ++row)
    {
        vals[row] = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, row, 0));
    }

    DFT_4(vals[0], vals[1], vals[2], vals[3]);

    // Scatter the results back to the same Y offsets of the destination
    for(int row = 0; row < 4; ++row)
    {
        vstore2(vals[row], 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, row, 0));
    }
}
532
/** Computes the first stage of a radix-5 DFT along axis 0.
 *
 * Each work-item loads five complex values (interleaved real/imaginary DATA_TYPE pairs):
 * values 0..3 with a single 8-wide vector load at its input position and value 4 at
 * X offset 4. It applies a radix-5 butterfly and writes the results back with the same
 * layout at its output position.
 *
 * @note The element type is selected at compile time with -DDATA_TYPE (e.g. -DDATA_TYPE=float).
 * @note Pass -DIN_PLACE at compile time to run in place: the output arguments are then
 *       omitted and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 */
__kernel void fft_radix_5_first_stage_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    // Resolve this work-item's source and destination positions
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D dst = src; // results overwrite the input
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // head = complex values 0..3 as (re, im) pairs, tail = complex value 4
    VEC_DATA_TYPE(DATA_TYPE, 8)
    head = vload8(0, (__global DATA_TYPE *)src.ptr);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    tail = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 4, 0, 0));

    DFT_5(head.s01, head.s23, head.s45, head.s67, tail);

    vstore8(head, 0, (__global DATA_TYPE *)dst.ptr);
    vstore2(tail, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 4, 0, 0));
}
583
/** Computes the first stage of a radix-5 DFT on axis 1.
 *
 * Each work-item reads the five complex values found at Y offsets 0 to 4 from its position,
 * applies a 5-point DFT to them and writes the five results to the same offsets of the destination.
 *
 * @note In order to perform the FFT function "in-place", the pre-processor -DIN_PLACE must be passed at compile time
 * @note The output_* arguments are only declared when IN_PLACE is not defined.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 */
__kernel void fft_radix_5_first_stage_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    // Source and destination views positioned at this work-item
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // One complex (real, imaginary) pair from each of five consecutive rows
    VEC_DATA_TYPE(DATA_TYPE, 2)
    row0 = vload2(0, (__global DATA_TYPE *)src.ptr),
    row1 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 1, 0)),
    row2 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 2, 0)),
    row3 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 3, 0)),
    row4 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 4, 0));

    // 5-point DFT, evaluated in place on the five row values
    DFT_5(row0, row1, row2, row3, row4);

    // Each result goes back to the row it was read from
    vstore2(row0, 0, (__global DATA_TYPE *)dst.ptr);
    vstore2(row1, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 1, 0));
    vstore2(row2, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 2, 0));
    vstore2(row3, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 3, 0));
    vstore2(row4, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 4, 0));
}
643
/** Computes the first stage of a radix-7 DFT on axis 0.
 *
 * Each work-item reads seven consecutive complex values along X (interleaved real/imaginary lanes),
 * applies a 7-point DFT to them and writes the seven results to the same positions of the destination.
 * NOTE(review): the X offsets count complex values, which assumes stride_x spans one (real, imaginary)
 * pair — confirm against the host code.
 *
 * @note In order to perform the FFT function "in-place", the pre-processor -DIN_PLACE must be passed at compile time
 * @note The output_* arguments are only declared when IN_PLACE is not defined.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 */
__kernel void fft_radix_7_first_stage_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    // Source and destination views positioned at this work-item
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // Seven complex values split into 4 + 2 + 1 so that no load reads past value 6
    VEC_DATA_TYPE(DATA_TYPE, 8)
    x0123 = vload8(0, (__global DATA_TYPE *)src.ptr);
    VEC_DATA_TYPE(DATA_TYPE, 4)
    x45 = vload4(0, (__global DATA_TYPE *)tensor3D_offset(&src, 4, 0, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    x6 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 6, 0, 0));

    // 7-point DFT; the macro updates the lane pairs it is given in place
    DFT_7(x0123.s01, x0123.s23, x0123.s45, x0123.s67, x45.s01, x45.s23, x6);

    // Write back using the same 4 + 2 + 1 split
    vstore8(x0123, 0, (__global DATA_TYPE *)dst.ptr);
    vstore4(x45, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 4, 0, 0));
    vstore2(x6, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 6, 0, 0));
}
697
/** Computes the first stage of a radix-7 DFT on axis 1.
 *
 * Each work-item reads the seven complex values found at Y offsets 0 to 6 from its position,
 * applies a 7-point DFT to them and writes the seven results to the same offsets of the destination.
 *
 * @note In order to perform the FFT function "in-place", the pre-processor -DIN_PLACE must be passed at compile time
 * @note The output_* arguments are only declared when IN_PLACE is not defined.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 */
__kernel void fft_radix_7_first_stage_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    // Source and destination views positioned at this work-item
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // One complex (real, imaginary) pair from each of seven consecutive rows
    VEC_DATA_TYPE(DATA_TYPE, 2)
    row0 = vload2(0, (__global DATA_TYPE *)src.ptr),
    row1 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 1, 0)),
    row2 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 2, 0)),
    row3 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 3, 0)),
    row4 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 4, 0)),
    row5 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 5, 0)),
    row6 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 6, 0));

    // 7-point DFT, evaluated in place on the seven row values
    DFT_7(row0, row1, row2, row3, row4, row5, row6);

    // Each result goes back to the row it was read from
    vstore2(row0, 0, (__global DATA_TYPE *)dst.ptr);
    vstore2(row1, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 1, 0));
    vstore2(row2, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 2, 0));
    vstore2(row3, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 3, 0));
    vstore2(row4, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 4, 0));
    vstore2(row5, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 5, 0));
    vstore2(row6, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 6, 0));
}
763
/** Computes the first stage of a radix-8 DFT on axis 0.
 *
 * Each work-item reads eight consecutive complex values along X as one 16-lane vector
 * (interleaved real/imaginary lanes), applies an 8-point DFT to them and writes the eight
 * results to the same positions of the destination.
 *
 * @note In order to perform the FFT function "in-place", the pre-processor -DIN_PLACE must be passed at compile time
 * @note The output_* arguments are only declared when IN_PLACE is not defined.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 */
__kernel void fft_radix_8_first_stage_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    // Source and destination views positioned at this work-item
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // All eight complex values fit in a single 16-lane load
    VEC_DATA_TYPE(DATA_TYPE, 16)
    pts = vload16(0, (__global DATA_TYPE *)src.ptr);

    // 8-point DFT over the eight (real, imaginary) lane pairs, updated in place
    DFT_8(pts.s01, pts.s23, pts.s45, pts.s67, pts.s89, pts.sAB, pts.sCD, pts.sEF);

    vstore16(pts, 0, (__global DATA_TYPE *)dst.ptr);
}
811
/** Computes the first stage of a radix-8 DFT on axis 1.
 *
 * Each work-item reads the eight complex values found at Y offsets 0 to 7 from its position,
 * applies an 8-point DFT to them and writes the eight results to the same offsets of the destination.
 *
 * @note In order to perform the FFT function "in-place", the pre-processor -DIN_PLACE must be passed at compile time
 * @note The output_* arguments are only declared when IN_PLACE is not defined.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 */
__kernel void fft_radix_8_first_stage_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    // Source and destination views positioned at this work-item
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // One complex (real, imaginary) pair from each of eight consecutive rows
    VEC_DATA_TYPE(DATA_TYPE, 2)
    row0 = vload2(0, (__global DATA_TYPE *)src.ptr),
    row1 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 1, 0)),
    row2 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 2, 0)),
    row3 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 3, 0)),
    row4 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 4, 0)),
    row5 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 5, 0)),
    row6 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 6, 0)),
    row7 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 7, 0));

    // 8-point DFT, evaluated in place on the eight row values
    DFT_8(row0, row1, row2, row3, row4, row5, row6, row7);

    // Each result goes back to the row it was read from
    vstore2(row0, 0, (__global DATA_TYPE *)dst.ptr);
    vstore2(row1, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 1, 0));
    vstore2(row2, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 2, 0));
    vstore2(row3, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 3, 0));
    vstore2(row4, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 4, 0));
    vstore2(row5, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 5, 0));
    vstore2(row6, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 6, 0));
    vstore2(row7, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 7, 0));
}
880
/** Computes a stage of a radix-2 FFT on axis 0.
 *
 * Work-item kx = get_global_id(0) handles butterfly nx = kx % Nx of group kx / Nx. It reads the
 * complex values at X positions n and n + Nx, where n = nx + (kx / Nx) * Ni. It then rotates the
 * second value by the twiddle factor cos(phi) + i*sin(phi), with phi = nx * exp_const, applies a
 * 2-point DFT and writes both results to the same positions of the destination. The Y and Z
 * positions are get_global_id(1) and get_global_id(2).
 *
 * @note In order to perform the FFT function "in-place", the pre-processor -DIN_PLACE must be passed at compile time
 * @note The output_* arguments are only declared when IN_PLACE is not defined.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 * @param[in]     Nx                                   The butterfly span: product of the radix orders of the previous stages
 * @param[in]     Ni                                   Nx * Ny, the distance between consecutive butterfly groups
 *                                                     (NOTE(review): Ny is presumably this stage's radix, 2 — confirm on the host side)
 * @param[in]     exp_const                            Twiddle angle increment per butterfly index (phi = nx * exp_const)
 */
__kernel void fft_radix_2_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // Each work-item computes a single radix-2 butterfly
    uint kx = get_global_id(0);

    // Index of the butterfly within its group
    uint nx = kx % Nx;

    // X position of the butterfly's first element
    uint n = nx + (kx / Nx) * Ni;

    // Get tensor pointers
    Tensor3D input = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    input.ptr += n * input.stride_x + get_global_id(1) * input.stride_y + get_global_id(2) * input.stride_z;
#ifdef IN_PLACE
    Tensor3D output = input;
#else  /* IN_PLACE */
    Tensor3D output = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    output.ptr += n * output.stride_x + get_global_id(1) * output.stride_y + get_global_id(2) * output.stride_z;
#endif /* IN_PLACE */

    // Load two complex input values
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c0 = vload2(0, (__global DATA_TYPE *)input.ptr);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c1 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, Nx, 0, 0));

    // Compute phi in float and round it to DATA_TYPE only once. With F16, casting nx and exp_const
    // to half separately rounds nx above 2048 and compounds three rounding errors into the twiddle
    // angle. With F32 the result is bit-identical to (DATA_TYPE)nx * (DATA_TYPE)exp_const.
    DATA_TYPE phi = (DATA_TYPE)((float)nx * exp_const);

    // Multiply by twiddle factor
    TWIDDLE_FACTOR_MULTIPLICATION(phi, c1);

    // Compute DFT N = 2
    DFT_2(c0, c1);

    // Store two complex output values
    vstore2(c0, 0, (__global DATA_TYPE *)output.ptr);
    vstore2(c1, 0, (__global DATA_TYPE *)tensor3D_offset(&output, Nx, 0, 0));
}
952
/** Computes a stage of a radix-2 FFT on axis 1.
 *
 * Work-item kx = get_global_id(1) handles butterfly nx = kx % Nx of group kx / Nx. It reads the
 * complex values at Y positions n and n + Nx, where n = nx + (kx / Nx) * Ni. It then rotates the
 * second value by the twiddle factor cos(phi) + i*sin(phi), with phi = nx * exp_const, applies a
 * 2-point DFT and writes both results to the same positions of the destination. The X and Z
 * positions are get_global_id(0) and get_global_id(2).
 *
 * @note In order to perform the FFT function "in-place", the pre-processor -DIN_PLACE must be passed at compile time
 * @note The output_* arguments are only declared when IN_PLACE is not defined.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 * @param[in]     Nx                                   The butterfly span: product of the radix orders of the previous stages
 * @param[in]     Ni                                   Nx * Ny, the distance between consecutive butterfly groups
 *                                                     (NOTE(review): Ny is presumably this stage's radix, 2 — confirm on the host side)
 * @param[in]     exp_const                            Twiddle angle increment per butterfly index (phi = nx * exp_const)
 */
__kernel void fft_radix_2_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // Each work-item computes a single radix-2 butterfly
    uint kx = get_global_id(1);

    // Index of the butterfly within its group
    uint nx = kx % Nx;

    // Y position of the butterfly's first element
    uint n = nx + (kx / Nx) * Ni;

    // Get tensor pointers
    Tensor3D input = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    input.ptr += get_global_id(0) * input.stride_x + n * input.stride_y + get_global_id(2) * input.stride_z;
#ifdef IN_PLACE
    Tensor3D output = input;
#else  /* IN_PLACE */
    Tensor3D output = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    output.ptr += get_global_id(0) * output.stride_x + n * output.stride_y + get_global_id(2) * output.stride_z;
#endif /* IN_PLACE */

    // Load two complex input values
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c0 = vload2(0, (__global DATA_TYPE *)input.ptr);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c1 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 0, Nx, 0));

    // Compute phi in float and round it to DATA_TYPE only once. With F16, casting nx and exp_const
    // to half separately rounds nx above 2048 and compounds three rounding errors into the twiddle
    // angle. With F32 the result is bit-identical to (DATA_TYPE)nx * (DATA_TYPE)exp_const.
    DATA_TYPE phi = (DATA_TYPE)((float)nx * exp_const);

    // Multiply by twiddle factor
    TWIDDLE_FACTOR_MULTIPLICATION(phi, c1);

    // Compute DFT N = 2
    DFT_2(c0, c1);

    // Store two complex output values
    vstore2(c0, 0, (__global DATA_TYPE *)output.ptr);
    vstore2(c1, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 0, Nx, 0));
}
1024
/** Computes a stage of a radix-3 FFT on axis 0.
 *
 * Work-item kx = get_global_id(0) handles butterfly nx = kx % Nx of group kx / Nx. It reads the
 * complex values at X positions n, n + Nx and n + 2 * Nx, where n = nx + (kx / Nx) * Ni. It then
 * rotates the second and third values by the twiddle factors for angles phi and 2 * phi, with
 * phi = nx * exp_const, applies a 3-point DFT and writes the three results to the same positions
 * of the destination. The Y and Z positions are get_global_id(1) and get_global_id(2).
 *
 * @note In order to perform the FFT function "in-place", the pre-processor -DIN_PLACE must be passed at compile time
 * @note The output_* arguments are only declared when IN_PLACE is not defined.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 * @param[in]     Nx                                   The butterfly span: product of the radix orders of the previous stages
 * @param[in]     Ni                                   Nx * Ny, the distance between consecutive butterfly groups
 *                                                     (NOTE(review): Ny is presumably this stage's radix, 3 — confirm on the host side)
 * @param[in]     exp_const                            Twiddle angle increment per butterfly index (phi = nx * exp_const)
 */
__kernel void fft_radix_3_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // Each work-item computes a single radix-3 butterfly
    uint kx = get_global_id(0);

    // Index of the butterfly within its group
    uint nx = kx % Nx;

    // X position of the butterfly's first element
    uint n = nx + (kx / Nx) * Ni;

    // Get tensor pointers
    Tensor3D input = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    input.ptr += n * input.stride_x + get_global_id(1) * input.stride_y + get_global_id(2) * input.stride_z;
#ifdef IN_PLACE
    Tensor3D output = input;
#else  /* IN_PLACE */
    Tensor3D output = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    output.ptr += n * output.stride_x + get_global_id(1) * output.stride_y + get_global_id(2) * output.stride_z;
#endif /* IN_PLACE */

    // Load three complex input values
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c0 = vload2(0, (__global DATA_TYPE *)input.ptr);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c1 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, Nx, 0, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c2 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 2 * Nx, 0, 0));

    // Compute phi in float and round it to DATA_TYPE only once. With F16, casting nx and exp_const
    // to half separately rounds nx above 2048 and compounds three rounding errors into the twiddle
    // angle. With F32 the result is bit-identical to (DATA_TYPE)nx * (DATA_TYPE)exp_const.
    DATA_TYPE phi = (DATA_TYPE)((float)nx * exp_const);

    // Multiply by twiddle factors
    TWIDDLE_FACTOR_MULTIPLICATION(phi, c1);
    TWIDDLE_FACTOR_MULTIPLICATION(2 * phi, c2);

    // Compute DFT N = 3
    DFT_3(c0, c1, c2);

    // Store three complex output values
    vstore2(c0, 0, (__global DATA_TYPE *)output.ptr);
    vstore2(c1, 0, (__global DATA_TYPE *)tensor3D_offset(&output, Nx, 0, 0));
    vstore2(c2, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 2 * Nx, 0, 0));
}
1100
/** Performs one radix-3 stage of an FFT along axis 1 (Y).
 *
 * Each work-item evaluates a single radix-3 butterfly. With kx = get_global_id(1), nx = kx % Nx and
 * n = nx + (kx / Nx) * Ni, the butterfly reads the complex values at Y positions n, n + Nx and n + 2 * Nx.
 * Value k is multiplied, as a complex number, by cos(k * phi) + i * sin(k * phi) with phi = nx * exp_const.
 * A 3-point DFT is then applied and the results are stored at the same Y positions of the destination.
 *
 * @note Pass -DIN_PLACE at compile time to run the stage in-place: the output tensor arguments are then
 *       not declared and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor (also the destination when IN_PLACE is defined). Supported data types: F16/F32
 * @param[in,out] input_stride_x                       Source tensor stride along X (in bytes)
 * @param[in,out] input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in,out] input_stride_y                       Source tensor stride along Y (in bytes)
 * @param[in,out] input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in,out] input_stride_z                       Source tensor stride along Z (in bytes)
 * @param[in,out] input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in,out] input_offset_first_element_in_bytes  Byte offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Destination tensor stride along X (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Destination tensor stride along Y (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Destination tensor stride along Z (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) Byte offset of the first element in the destination tensor
 * @param[in]     Nx                                   Butterfly span: product of the radix orders of the previous stages
 * @param[in]     Ni                                   Nx * Ny; distance (in elements) between the starts of consecutive butterfly groups
 * @param[in]     exp_const                            Exponent constant; the base twiddle angle of a work-item is nx * exp_const
 */
__kernel void fft_radix_3_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // One radix-3 butterfly per work-item: split its index into a group and a position in that group
    const uint butterfly = get_global_id(1);
    const uint group_idx = butterfly / Nx;
    const uint nx        = butterfly - group_idx * Nx; // equivalent to butterfly % Nx

    // Y index of the first butterfly element; consecutive groups start Ni rows apart
    const uint n = nx + group_idx * Ni;

    // Base twiddle angle: butterfly element k is rotated by k * phi
    const DATA_TYPE phi = (DATA_TYPE)nx * (DATA_TYPE)exp_const;

    const size_t xid = get_global_id(0);
    const size_t zid = get_global_id(2);

    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    src.ptr += xid * src.stride_x + n * src.stride_y + zid * src.stride_z;
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    dst.ptr += xid * dst.stride_x + n * dst.stride_y + zid * dst.stride_z;
#endif /* IN_PLACE */

    // Gather the three complex inputs, Nx rows apart along Y (all loads precede any store)
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c0 = vload2(0, (__global DATA_TYPE *)src.ptr);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c1 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, Nx, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c2 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 2 * Nx, 0));

    // Rotate every input except the first by its twiddle factor
    TWIDDLE_FACTOR_MULTIPLICATION(phi, c1);
    TWIDDLE_FACTOR_MULTIPLICATION(2 * phi, c2);

    // 3-point DFT on the rotated values
    DFT_3(c0, c1, c2);

    // Scatter the results to the matching Y positions of the destination
    vstore2(c0, 0, (__global DATA_TYPE *)dst.ptr);
    vstore2(c1, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, Nx, 0));
    vstore2(c2, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 2 * Nx, 0));
}
1176
/** Performs one radix-4 stage of an FFT along axis 0 (X).
 *
 * Each work-item evaluates a single radix-4 butterfly. With kx = get_global_id(0), nx = kx % Nx and
 * n = nx + (kx / Nx) * Ni, the butterfly reads the complex values at X positions n, n + Nx, n + 2 * Nx
 * and n + 3 * Nx. Value k is multiplied, as a complex number, by cos(k * phi) + i * sin(k * phi) with
 * phi = nx * exp_const. A 4-point DFT is then applied and the results are stored at the same X positions
 * of the destination.
 *
 * @note Pass -DIN_PLACE at compile time to run the stage in-place: the output tensor arguments are then
 *       not declared and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor (also the destination when IN_PLACE is defined). Supported data types: F16/F32
 * @param[in,out] input_stride_x                       Source tensor stride along X (in bytes)
 * @param[in,out] input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in,out] input_stride_y                       Source tensor stride along Y (in bytes)
 * @param[in,out] input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in,out] input_stride_z                       Source tensor stride along Z (in bytes)
 * @param[in,out] input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in,out] input_offset_first_element_in_bytes  Byte offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Destination tensor stride along X (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Destination tensor stride along Y (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Destination tensor stride along Z (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) Byte offset of the first element in the destination tensor
 * @param[in]     Nx                                   Butterfly span: product of the radix orders of the previous stages
 * @param[in]     Ni                                   Nx * Ny; distance (in elements) between the starts of consecutive butterfly groups
 * @param[in]     exp_const                            Exponent constant; the base twiddle angle of a work-item is nx * exp_const
 */
__kernel void fft_radix_4_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // One radix-4 butterfly per work-item: split its index into a group and a position in that group
    const uint butterfly = get_global_id(0);
    const uint group_idx = butterfly / Nx;
    const uint nx        = butterfly - group_idx * Nx; // equivalent to butterfly % Nx

    // X index of the first butterfly element; consecutive groups start Ni elements apart
    const uint n = nx + group_idx * Ni;

    // Base twiddle angle: butterfly element k is rotated by k * phi
    const DATA_TYPE phi = (DATA_TYPE)nx * (DATA_TYPE)exp_const;

    const size_t yid = get_global_id(1);
    const size_t zid = get_global_id(2);

    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    src.ptr += n * src.stride_x + yid * src.stride_y + zid * src.stride_z;
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    dst.ptr += n * dst.stride_x + yid * dst.stride_y + zid * dst.stride_z;
#endif /* IN_PLACE */

    // Gather the four complex inputs, Nx elements apart along X (all loads precede any store)
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c0 = vload2(0, (__global DATA_TYPE *)src.ptr);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c1 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, Nx, 0, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c2 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 2 * Nx, 0, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c3 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 3 * Nx, 0, 0));

    // Rotate every input except the first by its twiddle factor
    TWIDDLE_FACTOR_MULTIPLICATION(phi, c1);
    TWIDDLE_FACTOR_MULTIPLICATION(2 * phi, c2);
    TWIDDLE_FACTOR_MULTIPLICATION(3 * phi, c3);

    // 4-point DFT on the rotated values
    DFT_4(c0, c1, c2, c3);

    // Scatter the results to the matching X positions of the destination
    vstore2(c0, 0, (__global DATA_TYPE *)dst.ptr);
    vstore2(c1, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, Nx, 0, 0));
    vstore2(c2, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 2 * Nx, 0, 0));
    vstore2(c3, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 3 * Nx, 0, 0));
}
1256
/** Performs one radix-4 stage of an FFT along axis 1 (Y).
 *
 * Each work-item evaluates a single radix-4 butterfly. With kx = get_global_id(1), nx = kx % Nx and
 * n = nx + (kx / Nx) * Ni, the butterfly reads the complex values at Y positions n, n + Nx, n + 2 * Nx
 * and n + 3 * Nx. Value k is multiplied, as a complex number, by cos(k * phi) + i * sin(k * phi) with
 * phi = nx * exp_const. A 4-point DFT is then applied and the results are stored at the same Y positions
 * of the destination.
 *
 * @note Pass -DIN_PLACE at compile time to run the stage in-place: the output tensor arguments are then
 *       not declared and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor (also the destination when IN_PLACE is defined). Supported data types: F16/F32
 * @param[in,out] input_stride_x                       Source tensor stride along X (in bytes)
 * @param[in,out] input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in,out] input_stride_y                       Source tensor stride along Y (in bytes)
 * @param[in,out] input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in,out] input_stride_z                       Source tensor stride along Z (in bytes)
 * @param[in,out] input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in,out] input_offset_first_element_in_bytes  Byte offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Destination tensor stride along X (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Destination tensor stride along Y (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Destination tensor stride along Z (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) Byte offset of the first element in the destination tensor
 * @param[in]     Nx                                   Butterfly span: product of the radix orders of the previous stages
 * @param[in]     Ni                                   Nx * Ny; distance (in elements) between the starts of consecutive butterfly groups
 * @param[in]     exp_const                            Exponent constant; the base twiddle angle of a work-item is nx * exp_const
 */
__kernel void fft_radix_4_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // One radix-4 butterfly per work-item: split its index into a group and a position in that group
    const uint butterfly = get_global_id(1);
    const uint group_idx = butterfly / Nx;
    const uint nx        = butterfly - group_idx * Nx; // equivalent to butterfly % Nx

    // Y index of the first butterfly element; consecutive groups start Ni rows apart
    const uint n = nx + group_idx * Ni;

    // Base twiddle angle: butterfly element k is rotated by k * phi
    const DATA_TYPE phi = (DATA_TYPE)nx * (DATA_TYPE)exp_const;

    const size_t xid = get_global_id(0);
    const size_t zid = get_global_id(2);

    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    src.ptr += xid * src.stride_x + n * src.stride_y + zid * src.stride_z;
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    dst.ptr += xid * dst.stride_x + n * dst.stride_y + zid * dst.stride_z;
#endif /* IN_PLACE */

    // Gather the four complex inputs, Nx rows apart along Y (all loads precede any store)
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c0 = vload2(0, (__global DATA_TYPE *)src.ptr);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c1 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, Nx, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c2 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 2 * Nx, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c3 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 3 * Nx, 0));

    // Rotate every input except the first by its twiddle factor
    TWIDDLE_FACTOR_MULTIPLICATION(phi, c1);
    TWIDDLE_FACTOR_MULTIPLICATION(2 * phi, c2);
    TWIDDLE_FACTOR_MULTIPLICATION(3 * phi, c3);

    // 4-point DFT on the rotated values
    DFT_4(c0, c1, c2, c3);

    // Scatter the results to the matching Y positions of the destination
    vstore2(c0, 0, (__global DATA_TYPE *)dst.ptr);
    vstore2(c1, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, Nx, 0));
    vstore2(c2, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 2 * Nx, 0));
    vstore2(c3, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 3 * Nx, 0));
}
1336
/** Performs one radix-5 stage of an FFT along axis 0 (X).
 *
 * Each work-item evaluates a single radix-5 butterfly. With kx = get_global_id(0), nx = kx % Nx and
 * n = nx + (kx / Nx) * Ni, the butterfly reads the complex values at X positions n, n + Nx, ..., n + 4 * Nx.
 * Value k is multiplied, as a complex number, by cos(k * phi) + i * sin(k * phi) with phi = nx * exp_const.
 * A 5-point DFT is then applied and the results are stored at the same X positions of the destination.
 *
 * @note Pass -DIN_PLACE at compile time to run the stage in-place: the output tensor arguments are then
 *       not declared and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor (also the destination when IN_PLACE is defined). Supported data types: F16/F32
 * @param[in,out] input_stride_x                       Source tensor stride along X (in bytes)
 * @param[in,out] input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in,out] input_stride_y                       Source tensor stride along Y (in bytes)
 * @param[in,out] input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in,out] input_stride_z                       Source tensor stride along Z (in bytes)
 * @param[in,out] input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in,out] input_offset_first_element_in_bytes  Byte offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Destination tensor stride along X (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Destination tensor stride along Y (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Destination tensor stride along Z (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) Byte offset of the first element in the destination tensor
 * @param[in]     Nx                                   Butterfly span: product of the radix orders of the previous stages
 * @param[in]     Ni                                   Nx * Ny; distance (in elements) between the starts of consecutive butterfly groups
 * @param[in]     exp_const                            Exponent constant; the base twiddle angle of a work-item is nx * exp_const
 */
__kernel void fft_radix_5_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // One radix-5 butterfly per work-item: split its index into a group and a position in that group
    const uint butterfly = get_global_id(0);
    const uint group_idx = butterfly / Nx;
    const uint nx        = butterfly - group_idx * Nx; // equivalent to butterfly % Nx

    // X index of the first butterfly element; consecutive groups start Ni elements apart
    const uint n = nx + group_idx * Ni;

    // Base twiddle angle: butterfly element k is rotated by k * phi
    const DATA_TYPE phi = (DATA_TYPE)nx * (DATA_TYPE)exp_const;

    const size_t yid = get_global_id(1);
    const size_t zid = get_global_id(2);

    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    src.ptr += n * src.stride_x + yid * src.stride_y + zid * src.stride_z;
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    dst.ptr += n * dst.stride_x + yid * dst.stride_y + zid * dst.stride_z;
#endif /* IN_PLACE */

    // Gather the five complex inputs, Nx elements apart along X (all loads precede any store)
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c0 = vload2(0, (__global DATA_TYPE *)src.ptr);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c1 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, Nx, 0, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c2 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 2 * Nx, 0, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c3 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 3 * Nx, 0, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c4 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 4 * Nx, 0, 0));

    // Rotate every input except the first by its twiddle factor
    TWIDDLE_FACTOR_MULTIPLICATION(phi, c1);
    TWIDDLE_FACTOR_MULTIPLICATION(2 * phi, c2);
    TWIDDLE_FACTOR_MULTIPLICATION(3 * phi, c3);
    TWIDDLE_FACTOR_MULTIPLICATION(4 * phi, c4);

    // 5-point DFT on the rotated values
    DFT_5(c0, c1, c2, c3, c4);

    // Scatter the results to the matching X positions of the destination
    vstore2(c0, 0, (__global DATA_TYPE *)dst.ptr);
    vstore2(c1, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, Nx, 0, 0));
    vstore2(c2, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 2 * Nx, 0, 0));
    vstore2(c3, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 3 * Nx, 0, 0));
    vstore2(c4, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 4 * Nx, 0, 0));
}
1420
/** Performs one radix-5 stage of an FFT along axis 1 (Y).
 *
 * Each work-item evaluates a single radix-5 butterfly. With kx = get_global_id(1), nx = kx % Nx and
 * n = nx + (kx / Nx) * Ni, the butterfly reads the complex values at Y positions n, n + Nx, ..., n + 4 * Nx.
 * Value k is multiplied, as a complex number, by cos(k * phi) + i * sin(k * phi) with phi = nx * exp_const.
 * A 5-point DFT is then applied and the results are stored at the same Y positions of the destination.
 *
 * @note Pass -DIN_PLACE at compile time to run the stage in-place: the output tensor arguments are then
 *       not declared and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor (also the destination when IN_PLACE is defined). Supported data types: F16/F32
 * @param[in,out] input_stride_x                       Source tensor stride along X (in bytes)
 * @param[in,out] input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in,out] input_stride_y                       Source tensor stride along Y (in bytes)
 * @param[in,out] input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in,out] input_stride_z                       Source tensor stride along Z (in bytes)
 * @param[in,out] input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in,out] input_offset_first_element_in_bytes  Byte offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Destination tensor stride along X (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Destination tensor stride along Y (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Destination tensor stride along Z (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) Byte offset of the first element in the destination tensor
 * @param[in]     Nx                                   Butterfly span: product of the radix orders of the previous stages
 * @param[in]     Ni                                   Nx * Ny; distance (in elements) between the starts of consecutive butterfly groups
 * @param[in]     exp_const                            Exponent constant; the base twiddle angle of a work-item is nx * exp_const
 */
__kernel void fft_radix_5_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // One radix-5 butterfly per work-item: split its index into a group and a position in that group
    const uint butterfly = get_global_id(1);
    const uint group_idx = butterfly / Nx;
    const uint nx        = butterfly - group_idx * Nx; // equivalent to butterfly % Nx

    // Y index of the first butterfly element; consecutive groups start Ni rows apart
    const uint n = nx + group_idx * Ni;

    // Base twiddle angle: butterfly element k is rotated by k * phi
    const DATA_TYPE phi = (DATA_TYPE)nx * (DATA_TYPE)exp_const;

    const size_t xid = get_global_id(0);
    const size_t zid = get_global_id(2);

    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    src.ptr += xid * src.stride_x + n * src.stride_y + zid * src.stride_z;
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    dst.ptr += xid * dst.stride_x + n * dst.stride_y + zid * dst.stride_z;
#endif /* IN_PLACE */

    // Gather the five complex inputs, Nx rows apart along Y (all loads precede any store)
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c0 = vload2(0, (__global DATA_TYPE *)src.ptr);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c1 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, Nx, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c2 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 2 * Nx, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c3 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 3 * Nx, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c4 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 0, 4 * Nx, 0));

    // Rotate every input except the first by its twiddle factor
    TWIDDLE_FACTOR_MULTIPLICATION(phi, c1);
    TWIDDLE_FACTOR_MULTIPLICATION(2 * phi, c2);
    TWIDDLE_FACTOR_MULTIPLICATION(3 * phi, c3);
    TWIDDLE_FACTOR_MULTIPLICATION(4 * phi, c4);

    // 5-point DFT on the rotated values
    DFT_5(c0, c1, c2, c3, c4);

    // Scatter the results to the matching Y positions of the destination
    vstore2(c0, 0, (__global DATA_TYPE *)dst.ptr);
    vstore2(c1, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, Nx, 0));
    vstore2(c2, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 2 * Nx, 0));
    vstore2(c3, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 3 * Nx, 0));
    vstore2(c4, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 0, 4 * Nx, 0));
}
1504
/** Performs one radix-7 stage of an FFT along axis 0 (X).
 *
 * Each work-item evaluates a single radix-7 butterfly. With kx = get_global_id(0), nx = kx % Nx and
 * n = nx + (kx / Nx) * Ni, the butterfly reads the complex values at X positions n, n + Nx, ..., n + 6 * Nx.
 * Value k is multiplied, as a complex number, by cos(k * phi) + i * sin(k * phi) with phi = nx * exp_const.
 * A 7-point DFT is then applied and the results are stored at the same X positions of the destination.
 *
 * @note Pass -DIN_PLACE at compile time to run the stage in-place: the output tensor arguments are then
 *       not declared and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor (also the destination when IN_PLACE is defined). Supported data types: F16/F32
 * @param[in,out] input_stride_x                       Source tensor stride along X (in bytes)
 * @param[in,out] input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in,out] input_stride_y                       Source tensor stride along Y (in bytes)
 * @param[in,out] input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in,out] input_stride_z                       Source tensor stride along Z (in bytes)
 * @param[in,out] input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in,out] input_offset_first_element_in_bytes  Byte offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Destination tensor stride along X (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Destination tensor stride along Y (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Destination tensor stride along Z (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) Byte offset of the first element in the destination tensor
 * @param[in]     Nx                                   Butterfly span: product of the radix orders of the previous stages
 * @param[in]     Ni                                   Nx * Ny; distance (in elements) between the starts of consecutive butterfly groups
 * @param[in]     exp_const                            Exponent constant; the base twiddle angle of a work-item is nx * exp_const
 */
__kernel void fft_radix_7_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // One radix-7 butterfly per work-item: split its index into a group and a position in that group
    const uint butterfly = get_global_id(0);
    const uint group_idx = butterfly / Nx;
    const uint nx        = butterfly - group_idx * Nx; // equivalent to butterfly % Nx

    // X index of the first butterfly element; consecutive groups start Ni elements apart
    const uint n = nx + group_idx * Ni;

    // Base twiddle angle: butterfly element k is rotated by k * phi
    const DATA_TYPE phi = (DATA_TYPE)nx * (DATA_TYPE)exp_const;

    const size_t yid = get_global_id(1);
    const size_t zid = get_global_id(2);

    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    src.ptr += n * src.stride_x + yid * src.stride_y + zid * src.stride_z;
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    dst.ptr += n * dst.stride_x + yid * dst.stride_y + zid * dst.stride_z;
#endif /* IN_PLACE */

    // Gather the seven complex inputs, Nx elements apart along X (all loads precede any store)
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c0 = vload2(0, (__global DATA_TYPE *)src.ptr);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c1 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, Nx, 0, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c2 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 2 * Nx, 0, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c3 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 3 * Nx, 0, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c4 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 4 * Nx, 0, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c5 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 5 * Nx, 0, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c6 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&src, 6 * Nx, 0, 0));

    // Rotate every input except the first by its twiddle factor
    TWIDDLE_FACTOR_MULTIPLICATION(phi, c1);
    TWIDDLE_FACTOR_MULTIPLICATION(2 * phi, c2);
    TWIDDLE_FACTOR_MULTIPLICATION(3 * phi, c3);
    TWIDDLE_FACTOR_MULTIPLICATION(4 * phi, c4);
    TWIDDLE_FACTOR_MULTIPLICATION(5 * phi, c5);
    TWIDDLE_FACTOR_MULTIPLICATION(6 * phi, c6);

    // 7-point DFT on the rotated values
    DFT_7(c0, c1, c2, c3, c4, c5, c6);

    // Scatter the results to the matching X positions of the destination
    vstore2(c0, 0, (__global DATA_TYPE *)dst.ptr);
    vstore2(c1, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, Nx, 0, 0));
    vstore2(c2, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 2 * Nx, 0, 0));
    vstore2(c3, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 3 * Nx, 0, 0));
    vstore2(c4, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 4 * Nx, 0, 0));
    vstore2(c5, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 5 * Nx, 0, 0));
    vstore2(c6, 0, (__global DATA_TYPE *)tensor3D_offset(&dst, 6 * Nx, 0, 0));
}
1596
1597/** Computes a stage of a radix-7 FFT on axis 1.
1598 *
1599 * @note In order to perform the FFT function "in-place", the pre-processor -DIN_PLACE must be passed at compile time
1600 *
1601 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F16/F32
1602 * @param[in,out] input_stride_x                       Stride of the source tensor in X dimension (in bytes)
1603 * @param[in,out] input_step_x                         input_stride_x * number of elements along X processed per workitem(in bytes)
1604 * @param[in,out] input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
1605 * @param[in,out] input_step_y                         input_stride_y * number of elements along Y processed per workitem(in bytes)
1606 * @param[in,out] input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
1607 * @param[in,out] input_step_z                         input_stride_z * number of elements along Z processed per workitem(in bytes)
1608 * @param[in,out] input_offset_first_element_in_bytes  The offset of the first element in the source tensor
1609 * @param[out]    output_ptr                           (Optional) Pointer to the destination image. Supported data types: same as @p input_ptr
1610 * @param[in]     output_stride_x                      (Optional) Stride of the destination image in X dimension (in bytes)
1611 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per workitem(in bytes)
1612 * @param[in]     output_stride_y                      (Optional) Stride of the destination image in Y dimension (in bytes)
1613 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per workitem(in bytes)
1614 * @param[in]     output_stride_z                      (Optional) Stride of the destination image in Z dimension (in bytes)
1615 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per workitem(in bytes)
1616 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination image
1617 * @param[in]     Nx                                   The butterfly span. Products of radix order of previous radix's stage
1618 * @param[in]     Ni                                   Nx * Ny.
1619 * @param[in]     exp_const                            Exponent constant
1620 */
1621__kernel void fft_radix_7_axis_1(
1622    TENSOR3D_DECLARATION(input)
1623#ifndef IN_PLACE
1624    ,
1625    TENSOR3D_DECLARATION(output)
1626#endif /* not IN_PLACE */
1627    ,
1628    uint Nx, uint Ni, float exp_const)
1629{
1630    // Each work-item computes a single radix-7
1631    uint kx = get_global_id(1);
1632
1633    // Compute nx
1634    uint nx = kx % Nx;
1635
1636    // Compute n index
1637    uint n = nx + (kx / Nx) * Ni;
1638
1639    // Get tensor pointers
1640    Tensor3D input = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
1641    input.ptr += get_global_id(0) * input.stride_x + n * input.stride_y + get_global_id(2) * input.stride_z;
1642#ifdef IN_PLACE
1643    Tensor3D output = input;
1644#else  /* IN_PLACE */
1645    Tensor3D output = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
1646    output.ptr += get_global_id(0) * output.stride_x + n * output.stride_y + get_global_id(2) * output.stride_z;
1647#endif /* IN_PLACE */
1648
1649    // Load seven complex input values
1650    VEC_DATA_TYPE(DATA_TYPE, 2)
1651    c0 = vload2(0, (__global DATA_TYPE *)input.ptr);
1652    VEC_DATA_TYPE(DATA_TYPE, 2)
1653    c1 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 0, Nx, 0));
1654    VEC_DATA_TYPE(DATA_TYPE, 2)
1655    c2 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 0, 2 * Nx, 0));
1656    VEC_DATA_TYPE(DATA_TYPE, 2)
1657    c3 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 0, 3 * Nx, 0));
1658    VEC_DATA_TYPE(DATA_TYPE, 2)
1659    c4 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 0, 4 * Nx, 0));
1660    VEC_DATA_TYPE(DATA_TYPE, 2)
1661    c5 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 0, 5 * Nx, 0));
1662    VEC_DATA_TYPE(DATA_TYPE, 2)
1663    c6 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 0, 6 * Nx, 0));
1664
1665    // Compute phi
1666    DATA_TYPE phi = (DATA_TYPE)nx * (DATA_TYPE)exp_const;
1667
1668    // Multiply by twiddle factor
1669    TWIDDLE_FACTOR_MULTIPLICATION(phi, c1);
1670    TWIDDLE_FACTOR_MULTIPLICATION(2 * phi, c2);
1671    TWIDDLE_FACTOR_MULTIPLICATION(3 * phi, c3);
1672    TWIDDLE_FACTOR_MULTIPLICATION(4 * phi, c4);
1673    TWIDDLE_FACTOR_MULTIPLICATION(5 * phi, c5);
1674    TWIDDLE_FACTOR_MULTIPLICATION(6 * phi, c6);
1675
1676    // Compute DFT N = 7
1677    DFT_7(c0, c1, c2, c3, c4, c5, c6);
1678
1679    // Store seven complex output values
1680    vstore2(c0, 0, (__global DATA_TYPE *)output.ptr);
1681    vstore2(c1, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 0, Nx, 0));
1682    vstore2(c2, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 0, 2 * Nx, 0));
1683    vstore2(c3, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 0, 3 * Nx, 0));
1684    vstore2(c4, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 0, 4 * Nx, 0));
1685    vstore2(c5, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 0, 5 * Nx, 0));
1686    vstore2(c6, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 0, 6 * Nx, 0));
1687}
1688
/** Computes a stage of a radix-8 FFT on axis 0.
 *
 * Each work-item processes one radix-8 butterfly:
 * - It reads the eight complex values located at X positions n, n + Nx, ..., n + 7 * Nx.
 * - It multiplies value k by the twiddle factor (cos(k * phi), sin(k * phi)), where phi = nx * exp_const.
 * - It applies an 8-point DFT.
 * - It writes the eight results to the same positions of the destination.
 *
 * Each complex value is stored as two consecutive DATA_TYPE elements (.x and .y of a 2-element vector).
 *
 * @note In order to perform the FFT function "in-place", the preprocessor macro -DIN_PLACE must be passed at compile time.
 *       In that case the output tensor arguments are omitted and the results overwrite the source tensor.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 * @param[in]     Nx                                   The butterfly span: product of the radix orders of the previous stages
 * @param[in]     Ni                                   Nx * Ny.
 * @param[in]     exp_const                            Exponent constant. The base twiddle angle of a work-item is nx * exp_const
 */
__kernel void fft_radix_8_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // Each work-item computes a single radix-8 butterfly along axis 0
    const uint kx = get_global_id(0);

    // Position of this butterfly inside the current span of Nx elements
    const uint nx = kx % Nx;

    // X index of the first of the eight elements handled by this work-item.
    // (kx / Nx) selects the butterfly group; consecutive groups start Ni elements apart.
    const uint n = nx + (kx / Nx) * Ni;

    // Get tensor pointers: Y and Z come directly from the work-item id, X is the butterfly start n
    Tensor3D input = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    input.ptr += n * input.stride_x + get_global_id(1) * input.stride_y + get_global_id(2) * input.stride_z;
#ifdef IN_PLACE
    // Results overwrite the source tensor
    Tensor3D output = input;
#else  /* IN_PLACE */
    Tensor3D output = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    output.ptr += n * output.stride_x + get_global_id(1) * output.stride_y + get_global_id(2) * output.stride_z;
#endif /* IN_PLACE */

    // Load eight complex input values, Nx elements apart along X
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c0 = vload2(0, (__global DATA_TYPE *)input.ptr);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c1 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, Nx, 0, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c2 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 2 * Nx, 0, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c3 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 3 * Nx, 0, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c4 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 4 * Nx, 0, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c5 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 5 * Nx, 0, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c6 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 6 * Nx, 0, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c7 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 7 * Nx, 0, 0));

    // Compute the base twiddle angle of this butterfly
    const DATA_TYPE phi = (DATA_TYPE)nx * (DATA_TYPE)exp_const;

    // Multiply input k by the twiddle factor (cos(k * phi), sin(k * phi)); input 0 has a unit factor
    TWIDDLE_FACTOR_MULTIPLICATION(phi, c1);
    TWIDDLE_FACTOR_MULTIPLICATION(2 * phi, c2);
    TWIDDLE_FACTOR_MULTIPLICATION(3 * phi, c3);
    TWIDDLE_FACTOR_MULTIPLICATION(4 * phi, c4);
    TWIDDLE_FACTOR_MULTIPLICATION(5 * phi, c5);
    TWIDDLE_FACTOR_MULTIPLICATION(6 * phi, c6);
    TWIDDLE_FACTOR_MULTIPLICATION(7 * phi, c7);

    // Compute DFT N = 8
    DFT_8(c0, c1, c2, c3, c4, c5, c6, c7);

    // Store eight complex output values at the same positions they were loaded from
    vstore2(c0, 0, (__global DATA_TYPE *)output.ptr);
    vstore2(c1, 0, (__global DATA_TYPE *)tensor3D_offset(&output, Nx, 0, 0));
    vstore2(c2, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 2 * Nx, 0, 0));
    vstore2(c3, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 3 * Nx, 0, 0));
    vstore2(c4, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 4 * Nx, 0, 0));
    vstore2(c5, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 5 * Nx, 0, 0));
    vstore2(c6, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 6 * Nx, 0, 0));
    vstore2(c7, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 7 * Nx, 0, 0));
}
1784
/** Computes a stage of a radix-8 FFT on axis 1.
 *
 * Each work-item processes one radix-8 butterfly:
 * - It reads the eight complex values located at Y positions n, n + Nx, ..., n + 7 * Nx.
 * - It multiplies value k by the twiddle factor (cos(k * phi), sin(k * phi)), where phi = nx * exp_const.
 * - It applies an 8-point DFT.
 * - It writes the eight results to the same positions of the destination.
 *
 * Each complex value is stored as two consecutive DATA_TYPE elements (.x and .y of a 2-element vector).
 *
 * @note In order to perform the FFT function "in-place", the preprocessor macro -DIN_PLACE must be passed at compile time.
 *       In that case the output tensor arguments are omitted and the results overwrite the source tensor.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 * @param[in]     Nx                                   The butterfly span: product of the radix orders of the previous stages
 * @param[in]     Ni                                   Nx * Ny.
 * @param[in]     exp_const                            Exponent constant. The base twiddle angle of a work-item is nx * exp_const
 */
__kernel void fft_radix_8_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // Each work-item computes a single radix-8 butterfly along axis 1
    const uint kx = get_global_id(1);

    // Position of this butterfly inside the current span of Nx elements
    const uint nx = kx % Nx;

    // Y index of the first of the eight elements handled by this work-item.
    // (kx / Nx) selects the butterfly group; consecutive groups start Ni elements apart.
    const uint n = nx + (kx / Nx) * Ni;

    // Get tensor pointers: X and Z come directly from the work-item id, Y is the butterfly start n
    Tensor3D input = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    input.ptr += get_global_id(0) * input.stride_x + n * input.stride_y + get_global_id(2) * input.stride_z;
#ifdef IN_PLACE
    // Results overwrite the source tensor
    Tensor3D output = input;
#else  /* IN_PLACE */
    Tensor3D output = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    output.ptr += get_global_id(0) * output.stride_x + n * output.stride_y + get_global_id(2) * output.stride_z;
#endif /* IN_PLACE */

    // Load eight complex input values, Nx elements apart along Y
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c0 = vload2(0, (__global DATA_TYPE *)input.ptr);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c1 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 0, Nx, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c2 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 0, 2 * Nx, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c3 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 0, 3 * Nx, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c4 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 0, 4 * Nx, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c5 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 0, 5 * Nx, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c6 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 0, 6 * Nx, 0));
    VEC_DATA_TYPE(DATA_TYPE, 2)
    c7 = vload2(0, (__global DATA_TYPE *)tensor3D_offset(&input, 0, 7 * Nx, 0));

    // Compute the base twiddle angle of this butterfly
    const DATA_TYPE phi = (DATA_TYPE)nx * (DATA_TYPE)exp_const;

    // Multiply input k by the twiddle factor (cos(k * phi), sin(k * phi)); input 0 has a unit factor
    TWIDDLE_FACTOR_MULTIPLICATION(phi, c1);
    TWIDDLE_FACTOR_MULTIPLICATION(2 * phi, c2);
    TWIDDLE_FACTOR_MULTIPLICATION(3 * phi, c3);
    TWIDDLE_FACTOR_MULTIPLICATION(4 * phi, c4);
    TWIDDLE_FACTOR_MULTIPLICATION(5 * phi, c5);
    TWIDDLE_FACTOR_MULTIPLICATION(6 * phi, c6);
    TWIDDLE_FACTOR_MULTIPLICATION(7 * phi, c7);

    // Compute DFT N = 8
    DFT_8(c0, c1, c2, c3, c4, c5, c6, c7);

    // Store eight complex output values at the same positions they were loaded from
    vstore2(c0, 0, (__global DATA_TYPE *)output.ptr);
    vstore2(c1, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 0, Nx, 0));
    vstore2(c2, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 0, 2 * Nx, 0));
    vstore2(c3, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 0, 3 * Nx, 0));
    vstore2(c4, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 0, 4 * Nx, 0));
    vstore2(c5, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 0, 5 * Nx, 0));
    vstore2(c6, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 0, 6 * Nx, 0));
    vstore2(c7, 0, (__global DATA_TYPE *)tensor3D_offset(&output, 0, 7 * Nx, 0));
}
1880#endif // defined(DATA_TYPE)