xref: /aosp_15_r20/external/ComputeLibrary/src/core/CL/cl_kernels/nchw/winograd_output_transform.cl (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1/*
2 * Copyright (c) 2018-2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "activation_float_helpers.h"
25#include "helpers.h"
26#include "tile_helpers.h"
27
28#if defined(NUM_TILES_X) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
29#if defined(VEC_SIZE) && VEC_SIZE == 2
30/** This OpenCL kernel performs Winograd output transform when the output tile is 2x2/2x1 or 1x2, the filter size 3x3/3x1 or 1x3 and the data layout is NCHW
31 *
32 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
33 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
34 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
35 * @note If this kernel is used to perform Winograd output transform 3x1, -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
36 * @note If this kernel is used to perform Winograd output transform 1x3, -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
37 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
38 * @note It is possible to select the activation function to apply using -DACTIVATION_TYPE e.g. -DACTIVATION_TYPE=relu
39 * @note A, B variables required by some activation functions are set using -DA_VAL= and -DB_VAL= respectively.
40 * @note Vector size should be given as a preprocessor argument using -DVEC_SIZE=size. Accepted values are -DVEC_SIZE=2 (for output_tile_size 2x2, 2x1, 1x2) and -DVEC_SIZE=4 (for output_tile_size 4x4, 4x1, 1x4)
41 *
42 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
43 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
44 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
45 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
46 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
47 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
48 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
49 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
50 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
51 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
52 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
53 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
54 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
55 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
56 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
57 * @param[in]  dst_stride_z                      Stride of the source tensor in Z dimension (in bytes)
58 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
59 * @param[in]  dst_stride_w                      Stride of the source tensor in W dimension (in bytes)
60 * @param[in]  dst_step_w                        dst_stride_w * number of elements along W processed per workitem(in bytes)
61 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
62 */
63__kernel void winograd_output_transform_2x2_3x3_nchw(
64    TENSOR4D_DECLARATION(src),
65    TENSOR4D_DECLARATION(dst)
66#if defined(HAS_BIAS)
67    ,
68    VECTOR_DECLARATION(bias)
69#endif // defined(HAS_BIAS)
70)
71{
72    // Each thread stores a 2x2/2x1 or 1x2 tile accordingly with the filter size
73#if defined(SRC_DEPTH)
74    Tensor4D       src             = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DEPTH);
75    const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
76#else  /* defined(SRC_DEPTH) */
77    Tensor3D       src             = CONVERT_TO_TENSOR3D_STRUCT(src);
78    const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
79#endif /* defined(SRC_DEPTH) */
80
81    // Load the values across the 16 or 4 channels to compose the 4x4 or 4x1 tile
82    DATA_TYPE d00 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z));
83    DATA_TYPE d01 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z));
84    DATA_TYPE d02 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z));
85    DATA_TYPE d03 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_z));
86
87#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
88    // Compute the 2x1 or 1x2 output tile
89    // out00 = d00 + d01 + d02
90    // out01 = d01 - d02 - d03
91
92    float out00 = d00 + d01 + d02;
93    float out01 = d01 - d02 - d03;
94#else  // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
95
96    DATA_TYPE d10 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z));
97    DATA_TYPE d11 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_z));
98    DATA_TYPE d12 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z));
99    DATA_TYPE d13 = *((__global DATA_TYPE *)(src_addr + 7 * src_stride_z));
100
101    DATA_TYPE d20 = *((__global DATA_TYPE *)(src_addr + 8 * src_stride_z));
102    DATA_TYPE d21 = *((__global DATA_TYPE *)(src_addr + 9 * src_stride_z));
103    DATA_TYPE d22 = *((__global DATA_TYPE *)(src_addr + 10 * src_stride_z));
104    DATA_TYPE d23 = *((__global DATA_TYPE *)(src_addr + 11 * src_stride_z));
105
106    DATA_TYPE d30 = *((__global DATA_TYPE *)(src_addr + 12 * src_stride_z));
107    DATA_TYPE d31 = *((__global DATA_TYPE *)(src_addr + 13 * src_stride_z));
108    DATA_TYPE d32 = *((__global DATA_TYPE *)(src_addr + 14 * src_stride_z));
109    DATA_TYPE d33 = *((__global DATA_TYPE *)(src_addr + 15 * src_stride_z));
110
111    // Compute the 2x2 output tile
112    float k0 = d01 + d11 + d21;
113    float k1 = d02 + d12 + d22;
114    float k2 = d11 - d21 - d31;
115    float k3 = d12 - d22 - d32;
116
117    // out00 = d00 + d10 + d20 + d01 + d11 + d21 + d02 + d12 + d22
118    // out01 = d01 + d11 + d21 - (d02 + d12 + d22) - (d03 + d13 + d23)
119    // out10 = d10 - d20 - d30 + (d11 - d21 - d31) + (d12 - d22 - d32)
120    // out11 = d11 - d21 - d31 - (d12 - d22 - d32) - (d13 - d23 - d33)
121
122    float out00 = d10;
123    float out01 = -d13;
124    float out10 = d10;
125    float out11 = -d13;
126
127    out00 += d00 + d20 + k0 + k1;
128    out01 += k0 - k1 - (d03 + d23);
129    out10 += -d20 - d30 + k2 + k3;
130    out11 += k2 - k3 + d23 + d33;
131#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
132
133    int y_in  = get_global_id(1);
134    int x_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
135    int y_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
136    int z_out = get_global_id(0);
137#if defined(SRC_DEPTH)
138    int batch = get_global_id(2) / SRC_DEPTH;
139#endif /* defined(SRC_DEPTH) */
140
141#if defined(HAS_BIAS)
142    // Add bias
143    Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
144
145    float b = (float) * ((__global DATA_TYPE *)(vector_offset(&bias, z_out)));
146
147    out00 += (float)b;
148    out01 += (float)b;
149#endif // defined(HAS_BIAS)
150
151    // Get output address
152#if defined(SRC_DEPTH)
153    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * sizeof(DATA_TYPE) + y_out * dst_stride_y + z_out * dst_stride_z + batch * dst_stride_w;
154#else  /* defined(SRC_DEPTH) */
155    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * sizeof(DATA_TYPE) + y_out * dst_stride_y + z_out * dst_stride_z;
156#endif /* defined(SRC_DEPTH) */
157
158    // Store the output tile
159#if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
160    const VEC_DATA_TYPE(DATA_TYPE, 2)
161    out0_dt                                            = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, CONVERT((VEC_DATA_TYPE(float, 2))(out00, out01), VEC_DATA_TYPE(DATA_TYPE, 2)), A_VAL, B_VAL);
162    *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y)) = out0_dt.s0;
163    *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y)) = out0_dt.s1;
164#else  // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
165    vstore2(ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, CONVERT((VEC_DATA_TYPE(float, 2))(out00, out01), VEC_DATA_TYPE(DATA_TYPE, 2)), A_VAL, B_VAL), 0,
166            (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
167#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
168
169#if !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
170#if defined(HAS_BIAS)
171    // Add bias
172    out10 += (DATA_TYPE)b;
173    out11 += (DATA_TYPE)b;
174#endif // defined(HAS_BIAS)
175    vstore2(ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, CONVERT((VEC_DATA_TYPE(float, 2))(out10, out11), VEC_DATA_TYPE(DATA_TYPE, 2)), A_VAL, B_VAL), 0,
176            (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
177#endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
178}
179#endif // defined(VEC_SIZE) && VEC_SIZE == 2
180
181#if defined(VEC_SIZE) && VEC_SIZE == 4
182/** This OpenCL kernel performs Winograd output transform when the output tile is 4x4, the filter size 3x3 and the data layout is NCHW
183 *
184 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
185 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
186 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
187 * @note If this kernel is used to perform Winograd output transform 3x1, -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
188 * @note If this kernel is used to perform Winograd output transform 1x3, -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
189 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
190 *
191 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
192 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
193 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
194 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
195 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
196 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
197 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
198 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
199 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
200 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
201 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
202 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
203 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
204 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
205 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
206 * @param[in]  dst_stride_z                      Stride of the source tensor in Z dimension (in bytes)
207 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
208 * @param[in]  dst_stride_w                      Stride of the source tensor in W dimension (in bytes)
209 * @param[in]  dst_step_w                        dst_stride_w * number of elements along W processed per workitem(in bytes)
210 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
211 */
212__kernel void winograd_output_transform_4x4_3x3_nchw(
213    TENSOR4D_DECLARATION(src),
214    TENSOR4D_DECLARATION(dst)
215#if defined(HAS_BIAS)
216    ,
217    VECTOR_DECLARATION(bias)
218#endif // defined(HAS_BIAS)
219)
220{
221    // Each thread stores a 4x4/4x1 or 1x4 tile
222#if defined(SRC_DEPTH)
223    Tensor4D       src             = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DEPTH);
224    const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
225#else  /* defined(SRC_DEPTH) */
226    Tensor3D       src             = CONVERT_TO_TENSOR3D_STRUCT(src);
227    const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
228#endif /* defined(SRC_DEPTH) */
229
230    // Load the values across the channels to compose the 6x6 or 6x1 tile
231    DATA_TYPE d00 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z));
232    DATA_TYPE d01 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z));
233    DATA_TYPE d02 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z));
234    DATA_TYPE d03 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_z));
235    DATA_TYPE d04 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z));
236    DATA_TYPE d05 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_z));
237
238#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
239    // Compute out00, out01, out02 and out03
240    float out00 = d00 + d01 + d02 + d03 + d04;
241    float out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04;
242    float out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04;
243    float out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05;
244#else  // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
245
246    DATA_TYPE d10 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z));
247    DATA_TYPE d11 = *((__global DATA_TYPE *)(src_addr + 7 * src_stride_z));
248    DATA_TYPE d12 = *((__global DATA_TYPE *)(src_addr + 8 * src_stride_z));
249    DATA_TYPE d13 = *((__global DATA_TYPE *)(src_addr + 9 * src_stride_z));
250    DATA_TYPE d14 = *((__global DATA_TYPE *)(src_addr + 10 * src_stride_z));
251    DATA_TYPE d15 = *((__global DATA_TYPE *)(src_addr + 11 * src_stride_z));
252
253    DATA_TYPE d20 = *((__global DATA_TYPE *)(src_addr + 12 * src_stride_z));
254    DATA_TYPE d21 = *((__global DATA_TYPE *)(src_addr + 13 * src_stride_z));
255    DATA_TYPE d22 = *((__global DATA_TYPE *)(src_addr + 14 * src_stride_z));
256    DATA_TYPE d23 = *((__global DATA_TYPE *)(src_addr + 15 * src_stride_z));
257    DATA_TYPE d24 = *((__global DATA_TYPE *)(src_addr + 16 * src_stride_z));
258    DATA_TYPE d25 = *((__global DATA_TYPE *)(src_addr + 17 * src_stride_z));
259
260    DATA_TYPE d30 = *((__global DATA_TYPE *)(src_addr + 18 * src_stride_z));
261    DATA_TYPE d31 = *((__global DATA_TYPE *)(src_addr + 19 * src_stride_z));
262    DATA_TYPE d32 = *((__global DATA_TYPE *)(src_addr + 20 * src_stride_z));
263    DATA_TYPE d33 = *((__global DATA_TYPE *)(src_addr + 21 * src_stride_z));
264    DATA_TYPE d34 = *((__global DATA_TYPE *)(src_addr + 22 * src_stride_z));
265    DATA_TYPE d35 = *((__global DATA_TYPE *)(src_addr + 23 * src_stride_z));
266
267    DATA_TYPE d40 = *((__global DATA_TYPE *)(src_addr + 24 * src_stride_z));
268    DATA_TYPE d41 = *((__global DATA_TYPE *)(src_addr + 25 * src_stride_z));
269    DATA_TYPE d42 = *((__global DATA_TYPE *)(src_addr + 26 * src_stride_z));
270    DATA_TYPE d43 = *((__global DATA_TYPE *)(src_addr + 27 * src_stride_z));
271    DATA_TYPE d44 = *((__global DATA_TYPE *)(src_addr + 28 * src_stride_z));
272    DATA_TYPE d45 = *((__global DATA_TYPE *)(src_addr + 29 * src_stride_z));
273
274    DATA_TYPE d50 = *((__global DATA_TYPE *)(src_addr + 30 * src_stride_z));
275    DATA_TYPE d51 = *((__global DATA_TYPE *)(src_addr + 31 * src_stride_z));
276    DATA_TYPE d52 = *((__global DATA_TYPE *)(src_addr + 32 * src_stride_z));
277    DATA_TYPE d53 = *((__global DATA_TYPE *)(src_addr + 33 * src_stride_z));
278    DATA_TYPE d54 = *((__global DATA_TYPE *)(src_addr + 34 * src_stride_z));
279    DATA_TYPE d55 = *((__global DATA_TYPE *)(src_addr + 35 * src_stride_z));
280
281    // Compute out00, out01, out02 and out03
282    float out00 = (float)d01 + (float)d21 + (float)d41 + (float)d11 + (float)d31;
283    float out01 = (float)d01 + (float)d21 + (float)d41 + (float)d11 + (float)d31;
284    float out02 = (float)d01 + (float)d21 + (float)d41 + (float)d11 + (float)d31;
285    float out03 = (float)d01 + d21 + (float)d41 + (float)d11 + (float)d31;
286
287    float k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44;
288    float k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44;
289
290    out00 += k0 + d00 + d02 + d10 + d12 + d20 + d22 + d30 + d32 + d40 + d42;
291    out01 += k1 - d02 - d12 - d22 - d32 - d42;
292    out02 += 4.0f * k0 + d02 + d12 + d22 + d32 + d42;
293    out03 += 4.0f * k1 - d02 - d12 - d22 - d32 - d42 + d05 + d15 + d25 + d35 + d45;
294
295    // Compute out10, out11, out12 and out13
296    float out10 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
297    float out11 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
298    float out12 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
299    float out13 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
300
301    k0 = d13 + d14 - d23 - d24 + 2.0f * d33 + 2.0f * d34 - 2.0f * d43 - 2.0f * d44;
302    k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 4.0f * d33 - 4.0f * d34 - 4.0f * d43 + 4.0f * d44;
303
304    out10 += k0 + d10 + d12 - d20 - d22 + 2.0f * d30 + 2.0f * d32 - 2.0f * d40 - 2.0f * d42;
305    out11 += k1 - d12 + d22 - 2.0f * d32 + 2.0f * d42;
306    out12 += 4.0f * k0 + d12 - d22 + 2.0f * d32 - 2.0f * d42;
307    out13 += 4.0f * k1 - d12 + d15 + d22 - d25 - 2.0f * d32 + 2.0f * d35 + 2.0f * d42 - 2.0f * d45;
308
309    // Compute out20, out21, out22 and out23
310    float out20 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
311    float out21 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
312    float out22 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
313    float out23 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
314
315    k0 = d13 + d14 + d23 + d24 + 4.0f * d33 + 4.0f * d34 + 4.0f * d43 + 4.0f * d44;
316    k1 = 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 8.0f * d33 - 8.0f * d34 + 8.0f * d43 - 8.0f * d44;
317
318    out20 += k0 + d10 + d12 + d20 + d22 + 4.0f * d30 + 4.0f * d32 + 4.0f * d40 + 4.0f * d42;
319    out21 += k1 - d12 - d22 - 4.0f * d32 - 4.0f * d42;
320    out22 += 4.0f * k0 + d12 + d22 + 4.0f * d32 + 4.0f * d42;
321    out23 += 4.0f * k1 - d12 + d15 - d22 + d25 - 4.0f * d32 + 4.0f * d35 - 4.0f * d42 + 4.0f * d45;
322
323    // Compute out30, out31, out32 and out33
324    float out30 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
325    float out31 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
326    float out32 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
327    float out33 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
328
329    k0 = d13 + d14 - d23 - d24 + 8.0f * d33 + 8.0f * d34 - 8.0f * d43 - 8.0f * d44 + d53 + d54;
330    k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 16.0f * d33 - 16.0f * d34 - 16.0f * d43 + 16.0f * d44 + 2.0f * d53 - 2.0f * d54;
331
332    out30 += k0 + d10 + d12 - d20 - d22 + 8.0f * d30 + 8.0f * d32 - 8.0f * d40 - 8.0f * d42 + d50 + d52;
333    out31 += k1 - d12 + d22 - 8.0f * d32 + 8.0f * d42 - d52;
334    out32 += 4.0f * k0 + d12 - d22 + 8.0f * d32 - 8.0f * d42 + d52;
335    out33 += 4.0f * k1 - d12 + d15 + d22 - d25 - 8.0f * d32 + 8.0f * d35 + 8.0f * d42 - 8.0f * d45 - d52 + d55;
336#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
337
338    int y_in  = get_global_id(1);
339    int x_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
340    int y_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
341    int z_out = get_global_id(0);
342#if defined(SRC_DEPTH)
343    int batch = get_global_id(2) / SRC_DEPTH;
344#endif /* defined(SRC_DEPTH) */
345
346#if defined(HAS_BIAS)
347    // Add bias
348    Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
349
350    float b = (float) * ((__global DATA_TYPE *)(vector_offset(&bias, z_out)));
351
352    out00 += (float)b;
353    out01 += (float)b;
354    out02 += (float)b;
355    out03 += (float)b;
356#endif // defined(HAS_BIAS)
357
358    // Get output address
359#if defined(SRC_DEPTH)
360    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * sizeof(DATA_TYPE) + y_out * dst_stride_y + z_out * dst_stride_z + batch * dst_stride_w;
361#else  /* defined(SRC_DEPTH) */
362    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * sizeof(DATA_TYPE) + y_out * dst_stride_y + z_out * dst_stride_z;
363#endif /* defined(SRC_DEPTH) */
364
365    // Store the output tile
366    const VEC_DATA_TYPE(DATA_TYPE, 4)
367    out0_dt = CONVERT(ACTIVATION(ACTIVATION_TYPE, float, VEC_SIZE, (VEC_DATA_TYPE(float, 4))(out00, out01, out02, out03), A_VAL, B_VAL), VEC_DATA_TYPE(DATA_TYPE, 4));
368
369#if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
370    *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y)) = out0_dt.s0;
371    *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y)) = out0_dt.s1;
372    *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y)) = out0_dt.s2;
373    *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y)) = out0_dt.s3;
374#else  // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
375    vstore4(out0_dt, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
376#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
377
378#if !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
379#if defined(HAS_BIAS)
380    // Add bias
381    out10 += (float)b;
382    out11 += (float)b;
383    out12 += (float)b;
384    out13 += (float)b;
385
386    out20 += (float)b;
387    out21 += (float)b;
388    out22 += (float)b;
389    out23 += (float)b;
390
391    out30 += (float)b;
392    out31 += (float)b;
393    out32 += (float)b;
394    out33 += (float)b;
395#endif // defined(HAS_BIAS)
396    vstore4(CONVERT(ACTIVATION(ACTIVATION_TYPE, float, VEC_SIZE, (VEC_DATA_TYPE(float, 4))(out10, out11, out12, out13), A_VAL, B_VAL), VEC_DATA_TYPE(DATA_TYPE, 4)), 0,
397            (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
398    vstore4(CONVERT(ACTIVATION(ACTIVATION_TYPE, float, VEC_SIZE, (VEC_DATA_TYPE(float, 4))(out20, out21, out22, out23), A_VAL, B_VAL), VEC_DATA_TYPE(DATA_TYPE, 4)), 0,
399            (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
400    vstore4(CONVERT(ACTIVATION(ACTIVATION_TYPE, float, VEC_SIZE, (VEC_DATA_TYPE(float, 4))(out30, out31, out32, out33), A_VAL, B_VAL), VEC_DATA_TYPE(DATA_TYPE, 4)), 0,
401            (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
402#endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
403}
404
/* Combines eight inputs d0..d7 into the four components of `col`:
 *
 *   col.s0 = d0 + (d1 + d2) +       (d3 + d4) + 8 * (d5 + d6)
 *   col.s1 =      (d1 - d2) + 2 *   (d3 - d4) + 4 * (d5 - d6)
 *   col.s2 =      (d1 + d2) + 4 *   (d3 + d4) + 2 * (d5 + d6)
 *   col.s3 =      (d1 - d2) + 8 *   (d3 - d4) +     (d5 - d6) + d7
 *
 * `col` and `comm_fact` must be lvalues exposing .s0/.s1/.s2/.s3 and .s0/.s1/.s2 respectively;
 * `comm_fact` is scratch storage and is overwritten.
 * Every argument is parenthesised, so expressions such as `a - b` can be passed safely.
 * d1..d6 are evaluated twice: pass side-effect-free expressions.
 * The macro is a statement expression whose value is the value assigned to col.s3.
 */
#define COMPUTE_TMP_COL(col, d0, d1, d2, d3, d4, d5, d6, d7, comm_fact)                \
    ({                                                                                 \
        (comm_fact).s0 = (d1) + (d2);                                                  \
        (comm_fact).s1 = (d3) + (d4);                                                  \
        (comm_fact).s2 = (d5) + (d6);                                                  \
        \
        (col).s0 = (comm_fact).s0 + (comm_fact).s1 + 8.f * (comm_fact).s2 + (d0);      \
        (col).s2 = (comm_fact).s0 + 4.f * (comm_fact).s1 + 2.f * (comm_fact).s2;       \
        \
        (comm_fact).s0 = (d1) - (d2);                                                  \
        (comm_fact).s1 = (d3) - (d4);                                                  \
        (comm_fact).s2 = (d5) - (d6);                                                  \
        \
        (col).s1 = (comm_fact).s0 + 2.f * (comm_fact).s1 + 4.f * (comm_fact).s2;       \
        (col).s3 = (comm_fact).s0 + 8.f * (comm_fact).s1 + (comm_fact).s2 + (d7);      \
    })
421
422/** This OpenCL kernel performs Winograd output transform when the output tile is 4x4/4x1 or 1x4, the filter size 5x5/5x1 or 1x5 and the data layout is NCHW
423 *
424 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
425 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
426 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note If this kernel is used to perform Winograd output transform 5x1, -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd output transform 1x5, -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
429 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
430 *
431 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
432 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
433 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
434 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
435 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
436 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
437 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
438 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
439 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
440 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
441 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
442 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
443 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
444 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
445 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
449 * @param[in]  dst_step_w                        dst_stride_w * number of elements along W processed per workitem(in bytes)
450 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
451 */
452__kernel void winograd_output_transform_4x4_5x5_nchw(
453    TENSOR4D_DECLARATION(src),
454    TENSOR4D_DECLARATION(dst)
455#if defined(HAS_BIAS)
456    ,
457    VECTOR_DECLARATION(bias)
458#endif // defined(HAS_BIAS)
459)
460{
461    // Each thread stores a 4x4/4x1 or 1x4 tile
462#if defined(SRC_DEPTH)
463    Tensor4D       src             = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DEPTH);
464    const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
465#else  /* defined(SRC_DEPTH) */
466
467    Tensor3D       src             = CONVERT_TO_TENSOR3D_STRUCT(src);
468    const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
469#endif /* defined(SRC_DEPTH) */
470
471    // Compute output address
472    int y_in  = get_global_id(1);
473    int x_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
474    int y_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
475    int z_out = get_global_id(0);
476#if defined(SRC_DEPTH)
477    int batch = get_global_id(2) / SRC_DEPTH;
478#endif /* defined(SRC_DEPTH) */
479
480#if defined(SRC_DEPTH)
481    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * sizeof(DATA_TYPE) + y_out * dst_stride_y + z_out * dst_stride_z + batch * dst_stride_w;
482#else  /* defined(SRC_DEPTH) */
483
484    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * sizeof(DATA_TYPE) + y_out * dst_stride_y + z_out * dst_stride_z;
485#endif /* defined(SRC_DEPTH) */
486
487    // Load the values across the channels to compose the input tile
488    DATA_TYPE d00 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z));
489    DATA_TYPE d01 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z));
490    DATA_TYPE d02 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z));
491    DATA_TYPE d03 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_z));
492    DATA_TYPE d04 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z));
493    DATA_TYPE d05 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_z));
494    DATA_TYPE d06 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z));
495    DATA_TYPE d07 = *((__global DATA_TYPE *)(src_addr + 7 * src_stride_z));
496
497#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
498    // Compute out00, out01, out02 and out03
499    float out00 = d00 + d01 + d02 + d03 + d04 + 8.0f * d05 + 8.0f * d06;
500    float out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04 + 4.0f * d05 - 4.0f * d06;
501    float out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04 + 2.0f * d05 + 2.0f * d06;
502    float out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05 - d06 + d07;
503
504#if defined(HAS_BIAS)
505    // Add bias
506    Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
507
508    float b = (float) * ((__global DATA_TYPE *)(vector_offset(&bias, z_out)));
509
510    out00 += (DATA_TYPE)b;
511    out01 += (DATA_TYPE)b;
512    out02 += (DATA_TYPE)b;
513    out03 += (DATA_TYPE)b;
514#endif // defined(HAS_BIAS)
515
516    // Store the output tile
517#if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
518    VEC_DATA_TYPE(DATA_TYPE, 4)
519    out0_dt = CONVERT(ACTIVATION(ACTIVATION_TYPE, float, VEC_SIZE, (VEC_DATA_TYPE(float, 4))(out00, out01, out02, out03), A_VAL,
520                                 B_VAL),
521                      VEC_DATA_TYPE(DATA_TYPE, 4));
522    *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y)) = out0_dt.s0;
523    *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y)) = out0_dt.s1;
524    *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y)) = out0_dt.s2;
525    *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y)) = out0_dt.s3;
526#else  // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
527    vstore4(CONVERT(ACTIVATION(ACTIVATION_TYPE, float, VEC_SIZE, (VEC_DATA_TYPE(float, 4))(out00, out01, out02, out03), A_VAL, B_VAL), VEC_DATA_TYPE(DATA_TYPE, 4)),
528            0, (__global DATA_TYPE *)(dst_addr));
529#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
530
531#else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
532
533    DATA_TYPE d10 = *((__global DATA_TYPE *)(src_addr + 8 * src_stride_z));
534    DATA_TYPE d11 = *((__global DATA_TYPE *)(src_addr + 9 * src_stride_z));
535    DATA_TYPE d12 = *((__global DATA_TYPE *)(src_addr + 10 * src_stride_z));
536    DATA_TYPE d13 = *((__global DATA_TYPE *)(src_addr + 11 * src_stride_z));
537    DATA_TYPE d14 = *((__global DATA_TYPE *)(src_addr + 12 * src_stride_z));
538    DATA_TYPE d15 = *((__global DATA_TYPE *)(src_addr + 13 * src_stride_z));
539    DATA_TYPE d16 = *((__global DATA_TYPE *)(src_addr + 14 * src_stride_z));
540    DATA_TYPE d17 = *((__global DATA_TYPE *)(src_addr + 15 * src_stride_z));
541
542    DATA_TYPE d20 = *((__global DATA_TYPE *)(src_addr + 16 * src_stride_z));
543    DATA_TYPE d21 = *((__global DATA_TYPE *)(src_addr + 17 * src_stride_z));
544    DATA_TYPE d22 = *((__global DATA_TYPE *)(src_addr + 18 * src_stride_z));
545    DATA_TYPE d23 = *((__global DATA_TYPE *)(src_addr + 19 * src_stride_z));
546    DATA_TYPE d24 = *((__global DATA_TYPE *)(src_addr + 20 * src_stride_z));
547    DATA_TYPE d25 = *((__global DATA_TYPE *)(src_addr + 21 * src_stride_z));
548    DATA_TYPE d26 = *((__global DATA_TYPE *)(src_addr + 22 * src_stride_z));
549    DATA_TYPE d27 = *((__global DATA_TYPE *)(src_addr + 23 * src_stride_z));
550
551    DATA_TYPE d30 = *((__global DATA_TYPE *)(src_addr + 24 * src_stride_z));
552    DATA_TYPE d31 = *((__global DATA_TYPE *)(src_addr + 25 * src_stride_z));
553    DATA_TYPE d32 = *((__global DATA_TYPE *)(src_addr + 26 * src_stride_z));
554    DATA_TYPE d33 = *((__global DATA_TYPE *)(src_addr + 27 * src_stride_z));
555    DATA_TYPE d34 = *((__global DATA_TYPE *)(src_addr + 28 * src_stride_z));
556    DATA_TYPE d35 = *((__global DATA_TYPE *)(src_addr + 29 * src_stride_z));
557    DATA_TYPE d36 = *((__global DATA_TYPE *)(src_addr + 30 * src_stride_z));
558    DATA_TYPE d37 = *((__global DATA_TYPE *)(src_addr + 31 * src_stride_z));
559
560    DATA_TYPE d40 = *((__global DATA_TYPE *)(src_addr + 32 * src_stride_z));
561    DATA_TYPE d41 = *((__global DATA_TYPE *)(src_addr + 33 * src_stride_z));
562    DATA_TYPE d42 = *((__global DATA_TYPE *)(src_addr + 34 * src_stride_z));
563    DATA_TYPE d43 = *((__global DATA_TYPE *)(src_addr + 35 * src_stride_z));
564    DATA_TYPE d44 = *((__global DATA_TYPE *)(src_addr + 36 * src_stride_z));
565    DATA_TYPE d45 = *((__global DATA_TYPE *)(src_addr + 37 * src_stride_z));
566    DATA_TYPE d46 = *((__global DATA_TYPE *)(src_addr + 38 * src_stride_z));
567    DATA_TYPE d47 = *((__global DATA_TYPE *)(src_addr + 39 * src_stride_z));
568
569    DATA_TYPE d50 = *((__global DATA_TYPE *)(src_addr + 40 * src_stride_z));
570    DATA_TYPE d51 = *((__global DATA_TYPE *)(src_addr + 41 * src_stride_z));
571    DATA_TYPE d52 = *((__global DATA_TYPE *)(src_addr + 42 * src_stride_z));
572    DATA_TYPE d53 = *((__global DATA_TYPE *)(src_addr + 43 * src_stride_z));
573    DATA_TYPE d54 = *((__global DATA_TYPE *)(src_addr + 44 * src_stride_z));
574    DATA_TYPE d55 = *((__global DATA_TYPE *)(src_addr + 45 * src_stride_z));
575    DATA_TYPE d56 = *((__global DATA_TYPE *)(src_addr + 46 * src_stride_z));
576    DATA_TYPE d57 = *((__global DATA_TYPE *)(src_addr + 47 * src_stride_z));
577
578    DATA_TYPE d60 = *((__global DATA_TYPE *)(src_addr + 48 * src_stride_z));
579    DATA_TYPE d61 = *((__global DATA_TYPE *)(src_addr + 49 * src_stride_z));
580    DATA_TYPE d62 = *((__global DATA_TYPE *)(src_addr + 50 * src_stride_z));
581    DATA_TYPE d63 = *((__global DATA_TYPE *)(src_addr + 51 * src_stride_z));
582    DATA_TYPE d64 = *((__global DATA_TYPE *)(src_addr + 52 * src_stride_z));
583    DATA_TYPE d65 = *((__global DATA_TYPE *)(src_addr + 53 * src_stride_z));
584    DATA_TYPE d66 = *((__global DATA_TYPE *)(src_addr + 54 * src_stride_z));
585    DATA_TYPE d67 = *((__global DATA_TYPE *)(src_addr + 55 * src_stride_z));
586
587    DATA_TYPE d70 = *((__global DATA_TYPE *)(src_addr + 56 * src_stride_z));
588    DATA_TYPE d71 = *((__global DATA_TYPE *)(src_addr + 57 * src_stride_z));
589    DATA_TYPE d72 = *((__global DATA_TYPE *)(src_addr + 58 * src_stride_z));
590    DATA_TYPE d73 = *((__global DATA_TYPE *)(src_addr + 59 * src_stride_z));
591    DATA_TYPE d74 = *((__global DATA_TYPE *)(src_addr + 60 * src_stride_z));
592    DATA_TYPE d75 = *((__global DATA_TYPE *)(src_addr + 61 * src_stride_z));
593    DATA_TYPE d76 = *((__global DATA_TYPE *)(src_addr + 62 * src_stride_z));
594    DATA_TYPE d77 = *((__global DATA_TYPE *)(src_addr + 63 * src_stride_z));
595
596    // Compute the 8x4 intermediate tensor
597    VEC_DATA_TYPE(float, 4)
598    comm_fact0, comm_fact1, comm_fact2;
599    VEC_DATA_TYPE(float, 4)
600    tmp_col0, tmp_col1, tmp_col2, tmp_col3, tmp_col4, tmp_col5, tmp_col6, tmp_col7;
601
602    COMPUTE_TMP_COL(tmp_col0, d00, d10, d20, d30, d40, d50, d60, d70, comm_fact0);
603    COMPUTE_TMP_COL(tmp_col1, d01, d11, d21, d31, d41, d51, d61, d71, comm_fact0);
604    COMPUTE_TMP_COL(tmp_col2, d02, d12, d22, d32, d42, d52, d62, d72, comm_fact0);
605    COMPUTE_TMP_COL(tmp_col3, d03, d13, d23, d33, d43, d53, d63, d73, comm_fact0);
606    COMPUTE_TMP_COL(tmp_col4, d04, d14, d24, d34, d44, d54, d64, d74, comm_fact0);
607    COMPUTE_TMP_COL(tmp_col5, d05, d15, d25, d35, d45, d55, d65, d75, comm_fact0);
608    COMPUTE_TMP_COL(tmp_col6, d06, d16, d26, d36, d46, d56, d66, d76, comm_fact0);
609    COMPUTE_TMP_COL(tmp_col7, d07, d17, d27, d37, d47, d57, d67, d77, comm_fact0);
610
611    // Compute the 4x4 output tile
612    comm_fact0 = tmp_col1 + tmp_col2;
613    comm_fact1 = tmp_col3 + tmp_col4;
614    comm_fact2 = tmp_col5 + tmp_col6;
615
616    VEC_DATA_TYPE(float, 4)
617    out_col0 = comm_fact0 + comm_fact1 + (float)8.f * comm_fact2 + tmp_col0;
618    VEC_DATA_TYPE(float, 4)
619    out_col2 = comm_fact0 + (float)4.f * comm_fact1 + (float)2.f * comm_fact2;
620
621    comm_fact0 = tmp_col1 - tmp_col2;
622    comm_fact1 = tmp_col3 - tmp_col4;
623    comm_fact2 = tmp_col5 - tmp_col6;
624
625    VEC_DATA_TYPE(float, 4)
626    out_col1 = comm_fact0 + (float)2.f * comm_fact1 + (float)4.f * comm_fact2;
627    VEC_DATA_TYPE(float, 4)
628    out_col3 = comm_fact0 + (float)8.f * comm_fact1 + comm_fact2 + tmp_col7;
629
630#if defined(HAS_BIAS)
631    // Add bias
632    Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
633
634    float b = (float) * ((__global DATA_TYPE *)(vector_offset(&bias, z_out)));
635
636    out_col0 += (VEC_DATA_TYPE(float, 4))b;
637    out_col1 += (VEC_DATA_TYPE(float, 4))b;
638    out_col2 += (VEC_DATA_TYPE(float, 4))b;
639    out_col3 += (VEC_DATA_TYPE(float, 4))b;
640#endif // defined(HAS_BIAS)
641
642    // Store the output tile
643    vstore4(CONVERT(ACTIVATION(ACTIVATION_TYPE, float, VEC_SIZE, (VEC_DATA_TYPE(float, 4))(out_col0.s0, out_col1.s0, out_col2.s0, out_col3.s0), A_VAL, B_VAL),
644                    VEC_DATA_TYPE(DATA_TYPE, 4)),
645            0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
646    vstore4(CONVERT(ACTIVATION(ACTIVATION_TYPE, float, VEC_SIZE, (VEC_DATA_TYPE(float, 4))(out_col0.s1, out_col1.s1, out_col2.s1, out_col3.s1), A_VAL, B_VAL),
647                    VEC_DATA_TYPE(DATA_TYPE, 4)),
648            0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
649    vstore4(CONVERT(ACTIVATION(ACTIVATION_TYPE, float, VEC_SIZE, (VEC_DATA_TYPE(float, 4))(out_col0.s2, out_col1.s2, out_col2.s2, out_col3.s2), A_VAL, B_VAL),
650                    VEC_DATA_TYPE(DATA_TYPE, 4)),
651            0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
652    vstore4(CONVERT(ACTIVATION(ACTIVATION_TYPE, float, VEC_SIZE, (VEC_DATA_TYPE(float, 4))(out_col0.s3, out_col1.s3, out_col2.s3, out_col3.s3), A_VAL, B_VAL),
653                    VEC_DATA_TYPE(DATA_TYPE, 4)),
654            0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
655#endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
656}
657#endif // defined(VEC_SIZE) && VEC_SIZE == 4
658
659#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL)
660#if defined(VEC_SIZE) && VEC_SIZE == 2
/** Winograd output transform for a 2x1 output tile with a 3x1 filter, NCHW data layout.
 *
 * This kernel is a thin entry point: every argument is handed over, unchanged and in declaration order,
 * to winograd_output_transform_2x2_3x3_nchw, which performs the actual transform.
 *
 * @note -DNUM_TILES_X must give the number of tiles along the X direction: e.g. -DNUM_TILES_X=16
 * @note -DOUTPUT_TILE_W must give the width of the output tile: e.g. -DOUTPUT_TILE_W=2
 * @note -DOUTPUT_TILE_H must give the height of the output tile: e.g. -DOUTPUT_TILE_H=1
 * @note -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note -DDATA_TYPE must give the data type: e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note When -DHAS_BIAS is defined the kernel takes an additional bias vector (bias_ptr, bias_stride_x,
 *       bias_step_x, bias_offset_first_element_in_bytes), which is forwarded as well.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 * @param[in]  dst_step_w                        dst_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_output_transform_2x1_3x1_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR4D_DECLARATION(dst)
#ifdef HAS_BIAS
    , VECTOR_DECLARATION(bias)
#endif /* HAS_BIAS */
)
{
    // Delegate to the 2x2/3x3 kernel: source arguments first, then destination, then the optional bias.
    winograd_output_transform_2x2_3x3_nchw(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y,
        src_stride_z, src_step_z, src_stride_w, src_step_w, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y,
        dst_stride_z, dst_step_z, dst_stride_w, dst_step_w, dst_offset_first_element_in_bytes
#ifdef HAS_BIAS
        , bias_ptr, bias_stride_x, bias_step_x, bias_offset_first_element_in_bytes
#endif /* HAS_BIAS */
    );
}
728
729#endif // defined(VEC_SIZE) && VEC_SIZE == 2
730
731#if defined(VEC_SIZE) && VEC_SIZE == 4
/** Winograd output transform for a 4x1 output tile with a 3x1 filter, NCHW data layout.
 *
 * This kernel is a thin entry point: every argument is handed over, unchanged and in declaration order,
 * to winograd_output_transform_4x4_3x3_nchw, which performs the actual transform.
 *
 * @note -DNUM_TILES_X must give the number of tiles along the X direction: e.g. -DNUM_TILES_X=16
 * @note -DOUTPUT_TILE_W must give the width of the output tile: e.g. -DOUTPUT_TILE_W=4
 * @note -DOUTPUT_TILE_H must give the height of the output tile: e.g. -DOUTPUT_TILE_H=1
 * @note -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note -DDATA_TYPE must give the data type: e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note When -DHAS_BIAS is defined the kernel takes an additional bias vector (bias_ptr, bias_stride_x,
 *       bias_step_x, bias_offset_first_element_in_bytes), which is forwarded as well.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 * @param[in]  dst_step_w                        dst_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_output_transform_4x1_3x1_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR4D_DECLARATION(dst)
#ifdef HAS_BIAS
    , VECTOR_DECLARATION(bias)
#endif /* HAS_BIAS */
)
{
    // Delegate to the 4x4/3x3 kernel: source arguments first, then destination, then the optional bias.
    winograd_output_transform_4x4_3x3_nchw(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y,
        src_stride_z, src_step_z, src_stride_w, src_step_w, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y,
        dst_stride_z, dst_step_z, dst_stride_w, dst_step_w, dst_offset_first_element_in_bytes
#ifdef HAS_BIAS
        , bias_ptr, bias_stride_x, bias_step_x, bias_offset_first_element_in_bytes
#endif /* HAS_BIAS */
    );
}
799
/** Winograd output transform for a 4x1 output tile with a 5x1 filter, NCHW data layout.
 *
 * This kernel is a thin entry point: every argument is handed over, unchanged and in declaration order,
 * to winograd_output_transform_4x4_5x5_nchw, which performs the actual transform.
 *
 * @note -DNUM_TILES_X must give the number of tiles along the X direction: e.g. -DNUM_TILES_X=16
 * @note -DOUTPUT_TILE_W must give the width of the output tile: e.g. -DOUTPUT_TILE_W=4
 * @note -DOUTPUT_TILE_H must give the height of the output tile: e.g. -DOUTPUT_TILE_H=1
 * @note -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note -DDATA_TYPE must give the data type: e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note When -DHAS_BIAS is defined the kernel takes an additional bias vector (bias_ptr, bias_stride_x,
 *       bias_step_x, bias_offset_first_element_in_bytes), which is forwarded as well.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 * @param[in]  dst_step_w                        dst_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_output_transform_4x1_5x1_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR4D_DECLARATION(dst)
#ifdef HAS_BIAS
    , VECTOR_DECLARATION(bias)
#endif /* HAS_BIAS */
)
{
    // Delegate to the 4x4/5x5 kernel: source arguments first, then destination, then the optional bias.
    winograd_output_transform_4x4_5x5_nchw(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y,
        src_stride_z, src_step_z, src_stride_w, src_step_w, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y,
        dst_stride_z, dst_step_z, dst_stride_w, dst_step_w, dst_offset_first_element_in_bytes
#ifdef HAS_BIAS
        , bias_ptr, bias_stride_x, bias_step_x, bias_offset_first_element_in_bytes
#endif /* HAS_BIAS */
    );
}
867
868#endif // defined(VEC_SIZE) && VEC_SIZE == 4
869#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL)
870
871#if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
872#if defined(VEC_SIZE) && VEC_SIZE == 2
/** Winograd output transform for a 1x2 output tile with a 1x3 filter, NCHW data layout.
 *
 * This kernel is a thin entry point: every argument is handed over, unchanged and in declaration order,
 * to winograd_output_transform_2x2_3x3_nchw, which performs the actual transform.
 *
 * @note -DNUM_TILES_X must give the number of tiles along the X direction: e.g. -DNUM_TILES_X=16
 * @note -DOUTPUT_TILE_W must give the width of the output tile: e.g. -DOUTPUT_TILE_W=1
 * @note -DOUTPUT_TILE_H must give the height of the output tile: e.g. -DOUTPUT_TILE_H=2
 * @note -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note -DDATA_TYPE must give the data type: e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note When -DHAS_BIAS is defined the kernel takes an additional bias vector (bias_ptr, bias_stride_x,
 *       bias_step_x, bias_offset_first_element_in_bytes), which is forwarded as well.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 * @param[in]  dst_step_w                        dst_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_output_transform_1x2_1x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR4D_DECLARATION(dst)
#ifdef HAS_BIAS
    , VECTOR_DECLARATION(bias)
#endif /* HAS_BIAS */
)
{
    // Delegate to the 2x2/3x3 kernel: source arguments first, then destination, then the optional bias.
    winograd_output_transform_2x2_3x3_nchw(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y,
        src_stride_z, src_step_z, src_stride_w, src_step_w, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y,
        dst_stride_z, dst_step_z, dst_stride_w, dst_step_w, dst_offset_first_element_in_bytes
#ifdef HAS_BIAS
        , bias_ptr, bias_stride_x, bias_step_x, bias_offset_first_element_in_bytes
#endif /* HAS_BIAS */
    );
}
940
941#endif // defined(VEC_SIZE) && VEC_SIZE == 2
942
943#if defined(VEC_SIZE) && VEC_SIZE == 4
944/** This OpenCL kernel performs Winograd output transform when the output tile is 1x4, the filter size 1x3 and the data layout is NCHW
945 *
946 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
947 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
948 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
949 * @note -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
950 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
951 *
952 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
953 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
954 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
955 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
956 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
957 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
958 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
959 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
960 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
961 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
962 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
963 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
964 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
965 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
966 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
967 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
968 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
969 * @param[in]  dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
970 * @param[in]  dst_step_w                        dst_stride_w * number of elements along W processed per workitem(in bytes)
971 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
972 */
973__kernel void winograd_output_transform_1x4_1x3_nchw(
974    TENSOR4D_DECLARATION(src),
975    TENSOR4D_DECLARATION(dst)
976#if defined(HAS_BIAS)
977    ,
978    VECTOR_DECLARATION(bias)
979#endif // defined(HAS_BIAS)
980)
981{
982    winograd_output_transform_4x4_3x3_nchw(src_ptr,
983                                           src_stride_x,
984                                           src_step_x,
985                                           src_stride_y,
986                                           src_step_y,
987                                           src_stride_z,
988                                           src_step_z,
989                                           src_stride_w,
990                                           src_step_w,
991                                           src_offset_first_element_in_bytes,
992                                           dst_ptr,
993                                           dst_stride_x,
994                                           dst_step_x,
995                                           dst_stride_y,
996                                           dst_step_y,
997                                           dst_stride_z,
998                                           dst_step_z,
999                                           dst_stride_w,
1000                                           dst_step_w,
1001                                           dst_offset_first_element_in_bytes
1002#if defined(HAS_BIAS)
1003                                           ,
1004                                           bias_ptr,
1005                                           bias_stride_x,
1006                                           bias_step_x,
1007                                           bias_offset_first_element_in_bytes
1008#endif // defined(HAS_BIAS)
1009                                          );
1010}
1011
1012/** This OpenCL kernel performs Winograd output transform when the output tile is 1x4, the filter size 1x5 and the data layout is NCHW
1013 *
1014 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
1015 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
1016 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
1017 * @note -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
1018 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
1019 *
1020 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
1021 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
1022 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
1023 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
1024 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
1025 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
1026 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
1027 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
1028 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
1029 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
1030 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
1031 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
1032 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
1033 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
1034 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
1035 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
1036 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
1037 * @param[in]  dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
1038 * @param[in]  dst_step_w                        dst_stride_w * number of elements along W processed per workitem(in bytes)
1039 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1040 */
1041__kernel void winograd_output_transform_1x4_1x5_nchw(
1042    TENSOR4D_DECLARATION(src),
1043    TENSOR4D_DECLARATION(dst)
1044#if defined(HAS_BIAS)
1045    ,
1046    VECTOR_DECLARATION(bias)
1047#endif // defined(HAS_BIAS)
1048)
1049{
1050    winograd_output_transform_4x4_5x5_nchw(src_ptr,
1051                                           src_stride_x,
1052                                           src_step_x,
1053                                           src_stride_y,
1054                                           src_step_y,
1055                                           src_stride_z,
1056                                           src_step_z,
1057                                           src_stride_w,
1058                                           src_step_w,
1059                                           src_offset_first_element_in_bytes,
1060                                           dst_ptr,
1061                                           dst_stride_x,
1062                                           dst_step_x,
1063                                           dst_stride_y,
1064                                           dst_step_y,
1065                                           dst_stride_z,
1066                                           dst_step_z,
1067                                           dst_stride_w,
1068                                           dst_step_w,
1069                                           dst_offset_first_element_in_bytes
1070#if defined(HAS_BIAS)
1071                                           ,
1072                                           bias_ptr,
1073                                           bias_stride_x,
1074                                           bias_step_x,
1075                                           bias_offset_first_element_in_bytes
1076#endif // defined(HAS_BIAS)
1077                                          );
1078}
1079
1080#endif // defined(VEC_SIZE) && VEC_SIZE == 4
1081#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
1082#endif // defined(NUM_TILES_X) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
1083