xref: /aosp_15_r20/external/ComputeLibrary/src/core/CL/cl_kernels/common/gemm.cl (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1/*
2 * Copyright (c) 2017-2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "gemm_helpers.h"
25#include "repeat.h"
26
27#if defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE)
28
29#define CONCAT(a, b) a##b
30
31#define ARM_DOT1(a, b, c) \
32    ({                    \
33        c = fma(a, b, c); \
34    })
35#define ARM_DOT2(a, b, c)       \
36    ({                          \
37        c = fma(a.s0, b.s0, c); \
38        c = fma(a.s1, b.s1, c); \
39    })
40#define ARM_DOT3(a, b, c)           \
41    ({                              \
42        ARM_DOT2(a, b, c);          \
43        c = fma((a.s2), (b.s2), c); \
44    })
45#define ARM_DOT4(a, b, c)           \
46    ({                              \
47        ARM_DOT3(a, b, c);          \
48        c = fma((a.s3), (b.s3), c); \
49    })
50#define ARM_DOT8(a, b, c)            \
51    ({                               \
52        ARM_DOT4((a.lo), (b.lo), c); \
53        ARM_DOT4((a.hi), (b.hi), c); \
54    })
55#define ARM_DOT16(a, b, c)           \
56    ({                               \
57        ARM_DOT8((a.lo), (b.lo), c); \
58        ARM_DOT8((a.hi), (b.hi), c); \
59    })
60
61#if N0 == 2
62#define ARM_DOT_K0XN0(k0, a, b, c) \
63    ({                             \
64        CONCAT(ARM_DOT, k0)        \
65        ((a), (b##0), (c.s0));     \
66        CONCAT(ARM_DOT, k0)        \
67        ((a), (b##1), (c.s1));     \
68    })
69#elif N0 == 3 // N0 == 3
70#define ARM_DOT_K0XN0(k0, a, b, c) \
71    ({                             \
72        CONCAT(ARM_DOT, k0)        \
73        ((a), (b##0), (c.s0));     \
74        CONCAT(ARM_DOT, k0)        \
75        ((a), (b##1), (c.s1));     \
76        CONCAT(ARM_DOT, k0)        \
77        ((a), (b##2), (c.s2));     \
78    })
79#elif N0 == 4 // N0 == 4
80#define ARM_DOT_K0XN0(k0, a, b, c) \
81    ({                             \
82        CONCAT(ARM_DOT, k0)        \
83        ((a), (b##0), (c.s0));     \
84        CONCAT(ARM_DOT, k0)        \
85        ((a), (b##1), (c.s1));     \
86        CONCAT(ARM_DOT, k0)        \
87        ((a), (b##2), (c.s2));     \
88        CONCAT(ARM_DOT, k0)        \
89        ((a), (b##3), (c.s3));     \
90    })
91#elif N0 == 8 // N0 == 8
92#define ARM_DOT_K0XN0(k0, a, b, c) \
93    ({                             \
94        CONCAT(ARM_DOT, k0)        \
95        ((a), (b##0), (c.s0));     \
96        CONCAT(ARM_DOT, k0)        \
97        ((a), (b##1), (c.s1));     \
98        CONCAT(ARM_DOT, k0)        \
99        ((a), (b##2), (c.s2));     \
100        CONCAT(ARM_DOT, k0)        \
101        ((a), (b##3), (c.s3));     \
102        CONCAT(ARM_DOT, k0)        \
103        ((a), (b##4), (c.s4));     \
104        CONCAT(ARM_DOT, k0)        \
105        ((a), (b##5), (c.s5));     \
106        CONCAT(ARM_DOT, k0)        \
107        ((a), (b##6), (c.s6));     \
108        CONCAT(ARM_DOT, k0)        \
109        ((a), (b##7), (c.s7));     \
110    })
111#elif N0 == 16 // N0 == 16
112#define ARM_DOT_K0XN0(k0, a, b, c) \
113    ({                             \
114        CONCAT(ARM_DOT, k0)        \
115        ((a), (b##0), (c.s0));     \
116        CONCAT(ARM_DOT, k0)        \
117        ((a), (b##1), (c.s1));     \
118        CONCAT(ARM_DOT, k0)        \
119        ((a), (b##2), (c.s2));     \
120        CONCAT(ARM_DOT, k0)        \
121        ((a), (b##3), (c.s3));     \
122        CONCAT(ARM_DOT, k0)        \
123        ((a), (b##4), (c.s4));     \
124        CONCAT(ARM_DOT, k0)        \
125        ((a), (b##5), (c.s5));     \
126        CONCAT(ARM_DOT, k0)        \
127        ((a), (b##6), (c.s6));     \
128        CONCAT(ARM_DOT, k0)        \
129        ((a), (b##7), (c.s7));     \
130        CONCAT(ARM_DOT, k0)        \
131        ((a), (b##8), (c.s8));     \
132        CONCAT(ARM_DOT, k0)        \
133        ((a), (b##9), (c.s9));     \
134        CONCAT(ARM_DOT, k0)        \
135        ((a), (b##A), (c.sA));     \
136        CONCAT(ARM_DOT, k0)        \
137        ((a), (b##B), (c.sB));     \
138        CONCAT(ARM_DOT, k0)        \
139        ((a), (b##C), (c.sC));     \
140        CONCAT(ARM_DOT, k0)        \
141        ((a), (b##D), (c.sD));     \
142        CONCAT(ARM_DOT, k0)        \
143        ((a), (b##E), (c.sE));     \
144        CONCAT(ARM_DOT, k0)        \
145        ((a), (b##F), (c.sF));     \
146    })
147#else // N0 not supported
148#error "N0 value not supported"
149#endif // N0 conditions
150
151#if defined(GEMM_MM_RESHAPED_ONLY_RHS_T)
152/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
153 *  The LHS matrix is NOT reshaped
154 *  The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed
155 * @note This kernel is duplicated in /experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped_only_rhs.cl
156 *
157 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
158 * @note The GEMM's dimensions (M,N and K) must be passed at runtime as kernel parameters.
159 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
160 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
161 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
162 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
163 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
164 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
165 * @note Only the following configurations of M0, N0 and K0 are currently supported:
166 *  - M0 = 1, 2, 3, 4, 5, 6, 7, 8
167 *  - N0 = 2, 3, 4, 8, 16
168 *  - K0 = 2, 3, 4, 8, 16
169 *  - H0 >= 1
170 *
171 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
172 *       The activation function is performed after the bias addition
173 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
174 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
175 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
176 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
177 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
178 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
179 *
180 * @param[in]  lhs_ptr                            Pointer to the LHS matrix. Supported data type: F16/F32
181 * @param[in]  lhs_stride_x                       Stride of the LHS matrix in X dimension (in bytes)
182 * @param[in]  lhs_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
183 * @param[in]  lhs_stride_y                       Stride of the LHS matrix in Y dimension (in bytes)
184 * @param[in]  lhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
185 * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS matrix
186 * @param[in]  rhs_ptr                            Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
187 * @param[in]  rhs_stride_x                       Stride of the RHS reshaped matrix in X dimension (in bytes)
188 * @param[in]  rhs_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
189 * @param[in]  rhs_stride_y                       Stride of the RHS reshaped matrix in Y dimension (in bytes)
190 * @param[in]  rhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
191 * @param[in]  rhs_offset_first_element_in_bytes  The offset of the first element in the RHS reshaped matrix
192 * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
193 * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
194 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
195 * @param[in]  bias_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
196 * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
197 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
198 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data type: same as @p lhs_ptr
199 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
200 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
201 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
202 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
203 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
204 * @param[in]  lhs_stride_z                       Stride of the LHS matrix in Z dimension (in bytes)
205 * @param[in]  rhs_stride_z                       Stride of the RHS reshaped matrix in Z dimension (in bytes)
206 * @param[in]  bias_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
207 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
208 * @param[in]  lhs_cross_plane_pad                (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
209 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
210 * @param[in]  M                                  Number of rows in LHS matrix not reshaped.
211 * @param[in]  N                                  Number of columns in RHS matrix not reshaped.
212 * @param[in]  K                                  Number of columns in LHS matrix and rows in RHS matrix not reshaped.
213 */
214__kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
215                                          IMAGE_DECLARATION(rhs),
216#if defined(BETA)
217                                          IMAGE_DECLARATION(bias),
218#endif // defined(BETA)
219                                          IMAGE_DECLARATION(dst),
220                                          uint lhs_stride_z,
221                                          uint rhs_stride_z,
222#if defined(BETA)
223                                          uint bias_stride_z,
224#endif //defined(BETA)
225                                          uint dst_stride_z
226#if defined(REINTERPRET_INPUT_AS_3D)
227                                          ,
228                                          uint lhs_cross_plane_pad
229#endif // REINTERPRET_INPUT_AS_3D
230#if defined(REINTERPRET_OUTPUT_AS_3D)
231                                          ,
232                                          uint dst_cross_plane_pad
233#endif // REINTERPRET_OUTPUT_AS_3D
234                                          ,
235                                          const int M,
236                                          const int N,
237                                          const int K)
238{
239    // Block size
240#define RHS_BLOCK_SIZE ((K0) * (N0))
241
242    // RHS offset and step X
243#if defined(RHS_INTERLEAVE)
244#define RHS_OFFSET_X (K0)
245#define RHS_STEP_X ((K0) * (H0))
246#define RHS_STEP_LOOP (1)
247#else // defined(RHS_INTERLEAVE)
248#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
249#define RHS_STEP_X (K0)
250#define RHS_STEP_LOOP (H0)
251#endif // defined(RHS_INTERLEAVE)
252
253    uint x = get_global_id(0);
254    uint y = get_global_id(1);
255    uint z = get_global_id(2);
256
257    const bool cond_y = y == 0;
258    const bool cond_x = ((x + 1) * N0 >= N);
259
260#if defined(DUMMY_WORK_ITEMS)
261    if((x * N0 >= N) || (y * M0 >= M))
262    {
263        return;
264    }
265#endif // defined(DUMMY_WORK_ITEMS)
266
267    // Compute LHS matrix address
268    uint lhs_offset = lhs_offset_first_element_in_bytes + COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * (uint)lhs_stride_y;
269
270    // Compute RHS reshaped matrix address
271    uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
272
273#if defined(MATRIX_B_DEPTH)
274    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
275    rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
276#else  // defined(MATRIX_B_DEPTH)
277    rhs_offset += z * rhs_stride_z;
278#endif // defined(MATRIX_B_DEPTH)
279
280    REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
281    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
282
283#if defined(REINTERPRET_INPUT_AS_3D)
284    // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
285    CALCULATE_Z_OFFSET(M0, uint, zlhs, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
286
287    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
288    // multiply lhs_stride_z by DEPTH_GEMM3D
289    lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
290
291#else // defined(REINTERPRET_INPUT_AS_3D)
292
293    // Add offset for batched GEMM
294    lhs_offset += z * lhs_stride_z;
295
296#endif // defined(REINTERPRET_INPUT_AS_3D)
297
298    // Initialize the accumulators
299    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0)    c0=0,c1=0,c2=0,... c(M0-1)=0;
300
301    int i = 0;
302    for(; i <= (K - K0); i += K0)
303    {
304        // Supported cases (M0, K0):
305        // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
306        // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
307        // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
308        // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
309        // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
310        // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
311        // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
312        // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
313        // Load values from LHS matrix
314        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
315
316        // Load values from RHS reshaped matrix
317        LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
318
319        // Accumulate
320        ARM_DOT_K0XN0(K0, a0, b, c0);
321#if M0 > 1
322        ARM_DOT_K0XN0(K0, a1, b, c1);
323#endif // M0 > 1
324#if M0 > 2
325        ARM_DOT_K0XN0(K0, a2, b, c2);
326#endif // M0 > 2
327#if M0 > 3
328        ARM_DOT_K0XN0(K0, a3, b, c3);
329#endif // M0 > 3
330#if M0 > 4
331        ARM_DOT_K0XN0(K0, a4, b, c4);
332#endif // M0 > 4
333#if M0 > 5
334        ARM_DOT_K0XN0(K0, a5, b, c5);
335#endif // M0 > 5
336#if M0 > 6
337        ARM_DOT_K0XN0(K0, a6, b, c6);
338#endif // M0 > 6
339#if M0 > 7
340        ARM_DOT_K0XN0(K0, a7, b, c7);
341#endif // M0 > 7
342
343        lhs_offset += K0 * sizeof(DATA_TYPE);
344        rhs_offset += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
345    }
346
347    // Left-over accumulations
348    for(; i < K; ++i)
349    {
350        // Load values from LHS matrix
351        LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
352
353        // Load values from RHS reshaped matrix
354        LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
355
356        // Accumulate
357        ARM_DOT_K0XN0(1, a0, b, c0);
358#if M0 > 1
359        ARM_DOT_K0XN0(1, a1, b, c1);
360#endif // M0 > 1
361#if M0 > 2
362        ARM_DOT_K0XN0(1, a2, b, c2);
363#endif // M0 > 2
364#if M0 > 3
365        ARM_DOT_K0XN0(1, a3, b, c3);
366#endif // M0 > 3
367#if M0 > 4
368        ARM_DOT_K0XN0(1, a4, b, c4);
369#endif // M0 > 4
370#if M0 > 5
371        ARM_DOT_K0XN0(1, a5, b, c5);
372#endif // M0 > 5
373#if M0 > 6
374        ARM_DOT_K0XN0(1, a6, b, c6);
375#endif // M0 > 6
376#if M0 > 7
377        ARM_DOT_K0XN0(1, a7, b, c7);
378#endif // M0 > 7
379
380        lhs_offset += sizeof(DATA_TYPE);
381        rhs_offset += sizeof(DATA_TYPE);
382    }
383
384    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * dst_stride_y);
385
386    REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
387
388#if defined(REINTERPRET_OUTPUT_AS_3D)
389
390    // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
391    CALCULATE_Z_OFFSET(M0, uint, zout, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
392
393    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
394    // multiply dst_stride_z by DEPTH_GEMM3D
395    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
396
397#else // defined(REINTERPRET_OUTPUT_AS_3D)
398
399    // Add offset for batched GEMM
400    dst_addr += z * dst_stride_z;
401
402#endif // defined(REINTERPRET_OUTPUT_AS_3D)
403
404    // Multiply by the weight of matrix-matrix product and store the result
405#if defined(ALPHA)
406    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
407#endif // defined(ALPHA)
408
409    // Add beta*bias
410#if defined(BETA)
411#if defined(BROADCAST_BIAS)
412    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
413
414    LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, 1, PARTIAL_STORE_N0, false, cond_x);
415
416#ifndef UNIT_BETA
417    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
418#endif // UNIT_BIAS
419
420    // c = c + bias[broadcasted]
421    ADD_BLOCK_BROADCAST(M0, c, bias0);
422
423#else // defined(BROADCAST_BIAS)
424    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * bias_stride_y) + z * bias_stride_z;
425
426    LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
427
428#ifndef UNIT_BETA
429    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
430#endif // UNIT_BIAS
431
432    // c = c + bias
433    ADD_BLOCK(M0, c, bias);
434
435#endif // defined(BROADCAST_BIAS)
436#endif // defined(BETA)
437
438#if defined(ACTIVATION_TYPE)
439    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, N0, c, A_VAL, B_VAL);
440#endif // defined(ACTIVATION_TYPE)
441
442    // Store output block
443    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
444
445#undef RHS_BLOCK_SIZE
446#undef RHS_OFFSET_X
447#undef RHS_STEP_X
448#undef RHS_STEP_LOOP
449}
450#endif // defined(GEMM_MM_RESHAPED_ONLY_RHS_T)
451
452#if defined(OPENCL_IMAGE_SUPPORT) && defined(GEMM_MM_RESHAPED_ONLY_RHS_T_TEXTURE)
453/** This OpenCL kernel computes the matrix multiplication between 2 matrices. The RHS matrix is stored in OpenCL image
454 *  The LHS matrix is NOT reshaped
455 *  The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed
456 * @note This kernel is duplicated in /experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped_only_rhs.cl
457 *
458 * @note -DOPENCL_IMAGE_SUPPORT must be passed at compile time in order to compile this OpenCL kernel
459 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
460 * @note The GEMM's dimensions (M,N and K) must be passed at runtime as kernel parameters.
461 * @note The height of the RHS matrix, defined before creating the OpenCL image object from the OpenCL buffer, should be passed at compile time using -DRHS_HEIGHT=<value> (e.g. -DRHS_HEIGHT=32)
462 *       Since we cannot create a 3d image from a buffer, the third dimension could be collapsed with the second dimension so RHS_HEIGHT
463 *       could be different from the value returned by get_image_height(rhs_img).
464 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
465 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
466 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
467 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
468 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
469 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
470 * @note Only the following configurations of M0, N0 and K0 are currently supported:
471 *  - M0 = 1, 2, 3, 4, 5, 6, 7, 8
472 *  - N0 = 4, 8, 16
473 *  - K0 = 4, 8, 16
474 *  - H0 >= 1
475 *
476 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
477 *       The activation function is performed after the bias addition
478 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
479 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
480 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
481 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
482 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
483 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
484 *
485 * @param[in]  lhs_ptr                            Pointer to the LHS matrix. Supported data type: F32
486 * @param[in]  lhs_stride_x                       Stride of the LHS matrix in X dimension (in bytes)
487 * @param[in]  lhs_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
488 * @param[in]  lhs_stride_y                       Stride of the LHS matrix in Y dimension (in bytes)
489 * @param[in]  lhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
490 * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS matrix
491 * @param[in]  rhs_img                            The RHS reshaped matrix as OpenCL image object. Supported data type: same as @p lhs_ptr
492 * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
493 * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
494 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
495 * @param[in]  bias_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
496 * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
497 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
498 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data type: same as @p lhs_ptr
499 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
500 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
501 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
502 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
503 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
504 * @param[in]  lhs_stride_z                       Stride of the LHS matrix in Z dimension (in bytes)
505 * @param[in]  rhs_stride_z                       Stride of the RHS reshaped matrix in Z dimension (in bytes)
506 * @param[in]  bias_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
507 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
508 * @param[in]  lhs_cross_plane_pad                (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
509 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
510 * @param[in]  M                                  Number of rows in LHS matrix not reshaped.
511 * @param[in]  N                                  Number of columns in RHS matrix not reshaped.
512 * @param[in]  K                                  Number of columns in LHS matrix and rows in RHS matrix not reshaped.
513 */
514__kernel void gemm_mm_reshaped_only_rhs_t_texture(IMAGE_DECLARATION(lhs),
515                                                  __read_only image2d_t rhs_img,
516#if defined(BETA)
517                                                  IMAGE_DECLARATION(bias),
518#endif // defined(BETA)
519                                                  IMAGE_DECLARATION(dst),
520                                                  uint lhs_stride_z,
521                                                  uint rhs_stride_z,
522#if defined(BETA)
523                                                  uint bias_stride_z,
524#endif //defined(BETA)
525                                                  uint dst_stride_z
526#if defined(REINTERPRET_INPUT_AS_3D)
527                                                  ,
528                                                  uint lhs_cross_plane_pad
529#endif // REINTERPRET_INPUT_AS_3D
530#if defined(REINTERPRET_OUTPUT_AS_3D)
531                                                  ,
532                                                  uint dst_cross_plane_pad
533#endif // REINTERPRET_OUTPUT_AS_3D
534                                                  ,
535                                                  const int M,
536                                                  const int N,
537                                                  const int K)
538{
539    // Pixel unit
540#define PIXEL_UNIT CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(K0)
541
542    const uint LEFTOVER_K = K % K0;
543
544    // Block size
545#define RHS_BLOCK_SIZE (PIXEL_UNIT * (N0))
546
547    // RHS offset and step X
548#if defined(RHS_INTERLEAVE)
549#define RHS_OFFSET_X (PIXEL_UNIT)
550#define RHS_STEP_X (PIXEL_UNIT * (H0))
551#define RHS_STEP_LOOP (1)
552#else // defined(RHS_INTERLEAVE)
553#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
554#define RHS_STEP_X PIXEL_UNIT
555#define RHS_STEP_LOOP (H0)
556#endif // defined(RHS_INTERLEAVE)
557
558    uint x = get_global_id(0);
559    uint y = get_global_id(1);
560    uint z = get_global_id(2);
561
562    const bool cond_y = y == 0;
563    const bool cond_x = ((x + 1) * N0 >= N);
564
565#if defined(DUMMY_WORK_ITEMS)
566    if((x * N0 >= N) || (y * M0 >= M))
567    {
568        return;
569    }
570#endif // defined(DUMMY_WORK_ITEMS)
571
572    // Compute LHS matrix address
573    uint lhs_offset = lhs_offset_first_element_in_bytes + COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * (uint)lhs_stride_y;
574
575#if defined(MATRIX_B_DEPTH)
576    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
577    const uint z_rhs = (get_global_id(2) % MATRIX_B_DEPTH);
578#else  // defined(MATRIX_B_DEPTH)
579    const uint z_rhs = get_global_id(2);
580#endif // defined(MATRIX_B_DEPTH)
581
582    // Compute RHS matrix coordinates
583    uint       x_rhs = (get_global_id(0) % H0) * (uint)RHS_OFFSET_X;
584    const uint y_rhs = (get_global_id(0) / (uint)H0) + z_rhs * RHS_HEIGHT;
585
586    REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0);
587    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
588
589#if defined(REINTERPRET_INPUT_AS_3D)
590    // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
591    CALCULATE_Z_OFFSET(M0, uint, zlhs, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
592
593    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
594    // multiply lhs_stride_z by DEPTH_GEMM3D
595    lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
596
597#else // defined(REINTERPRET_INPUT_AS_3D)
598
599    // Add offset for batched GEMM
600    lhs_offset += z * lhs_stride_z;
601
602#endif // defined(REINTERPRET_INPUT_AS_3D)
603
604    // Initialize the accumulators
605    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0);
606
607    int i = 0;
608    for(; i <= (K - K0); i += K0)
609    {
610        // Load values from LHS matrix
611        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
612
613        // Load values from RHS matrix stored in a cl_image
614        REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), b, 0);
615        LOAD_TEXTURE2D(N0, PIXEL_UNIT, DATA_TYPE, b, rhs_img, x_rhs, y_rhs, RHS_STEP_X, 0);
616
617        // Accumulate
618        ARM_DOT_K0XN0(K0, a0, b, c0);
619#if M0 > 1
620        ARM_DOT_K0XN0(K0, a1, b, c1);
621#endif // M0 > 1
622#if M0 > 2
623        ARM_DOT_K0XN0(K0, a2, b, c2);
624#endif // M0 > 2
625#if M0 > 3
626        ARM_DOT_K0XN0(K0, a3, b, c3);
627#endif // M0 > 3
628#if M0 > 4
629        ARM_DOT_K0XN0(K0, a4, b, c4);
630#endif // M0 > 4
631#if M0 > 5
632        ARM_DOT_K0XN0(K0, a5, b, c5);
633#endif // M0 > 5
634#if M0 > 6
635        ARM_DOT_K0XN0(K0, a6, b, c6);
636#endif // M0 > 6
637#if M0 > 7
638        ARM_DOT_K0XN0(K0, a7, b, c7);
639#endif // M0 > 7
640
641        lhs_offset += K0 * sizeof(DATA_TYPE);
642        x_rhs += N0 * RHS_STEP_X * RHS_STEP_LOOP;
643    }
644
645    if(LEFTOVER_K != 0)
646    {
647        // Note: We cannot read out-of-bound elements from the RHS matrix because
648        // the RHS width is always multiple of K0. This is not be true for the LHS matrix
649        // Left-over accumulations for LHS matrix
650
651        union UNION_VEC_TYPE
652        {
653            DATA_TYPE s[K0];
654            VEC_DATA_TYPE(DATA_TYPE, K0)
655            v;
656        };
657
658        union UNION_VEC_TYPE a0 = {.v = 0 };
659#if M0 > 1
660        union UNION_VEC_TYPE a1 = {.v = 0 };
661#endif // M0 > 1
662#if M0 > 2
663        union UNION_VEC_TYPE a2 = {.v = 0 };
664#endif // M0 > 2
665#if M0 > 3
666        union UNION_VEC_TYPE a3 = {.v = 0 };
667#endif // M0 > 3
668#if M0 > 4
669        union UNION_VEC_TYPE a4 = {.v = 0 };
670#endif // M0 > 4
671#if M0 > 5
672        union UNION_VEC_TYPE a5 = {.v = 0 };
673#endif // M0 > 5
674#if M0 > 6
675        union UNION_VEC_TYPE a6 = {.v = 0 };
676#endif // M0 > 6
677#if M0 > 7
678        union UNION_VEC_TYPE a7 = {.v = 0 };
679#endif // M0 > 7
680
681        REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), b, 0);
682
683        // Load from RHS matrix
684        LOAD_TEXTURE2D(N0, PIXEL_UNIT, DATA_TYPE, b, rhs_img, x_rhs, y_rhs, RHS_STEP_X, 0);
685
686        // Load from LHS matrix
687        for(int k = 0; k < LEFTOVER_K; ++k)
688        {
689            a0.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zlhs0);
690#if M0 > 1
691            a1.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zlhs1);
692#endif // M0 > 1
693#if M0 > 2
694            a2.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zlhs2);
695#endif // M0 > 2
696#if M0 > 3
697            a3.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zlhs3);
698#endif // M0 > 3
699#if M0 > 4
700            a4.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zlhs4);
701#endif // M0 > 4
702#if M0 > 5
703            a5.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zlhs5);
704#endif // M0 > 5
705#if M0 > 6
706            a6.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zlhs6);
707#endif // M0 > 6
708#if M0 > 7
709            a7.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zlhs7);
710#endif // M0 > 7
711
712            lhs_offset += sizeof(DATA_TYPE);
713        }
714
715        // Accumulate
716        ARM_DOT_K0XN0(K0, a0.v, b, c0);
717#if M0 > 1
718        ARM_DOT_K0XN0(K0, a1.v, b, c1);
719#endif // M0 > 1
720#if M0 > 2
721        ARM_DOT_K0XN0(K0, a2.v, b, c2);
722#endif // M0 > 2
723#if M0 > 3
724        ARM_DOT_K0XN0(K0, a3.v, b, c3);
725#endif // M0 > 3
726#if M0 > 4
727        ARM_DOT_K0XN0(K0, a4.v, b, c4);
728#endif // M0 > 4
729#if M0 > 5
730        ARM_DOT_K0XN0(K0, a5.v, b, c5);
731#endif // M0 > 5
732#if M0 > 6
733        ARM_DOT_K0XN0(K0, a6.v, b, c6);
734#endif // M0 > 6
735#if M0 > 7
736        ARM_DOT_K0XN0(K0, a7.v, b, c7);
737#endif // M0 > 7
738    }
739
740    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * dst_stride_y);
741
742    REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
743
744#if defined(REINTERPRET_OUTPUT_AS_3D)
745
746    // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
747    CALCULATE_Z_OFFSET(M0, uint, zout, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
748
749    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
750    // multiply dst_stride_z by DEPTH_GEMM3D
751    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
752
753#else // defined(REINTERPRET_OUTPUT_AS_3D)
754
755    // Add offset for batched GEMM
756    dst_addr += z * dst_stride_z;
757
758#endif // defined(REINTERPRET_OUTPUT_AS_3D)
759
760    // Multiply by the weight of matrix-matrix product and store the result
761#if defined(ALPHA)
762    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
763#endif // defined(ALPHA)
764
765    // Add beta*bias
766#if defined(BETA)
767#if defined(BROADCAST_BIAS)
768    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
769
770    LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, 1, PARTIAL_STORE_N0, false, cond_x);
771
772#ifndef UNIT_BETA
773    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
774#endif // UNIT_BIAS
775
776    // c = c + bias[broadcasted]
777    ADD_BLOCK_BROADCAST(M0, c, bias0);
778
779#else // defined(BROADCAST_BIAS)
780    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * bias_stride_y) + z * bias_stride_z;
781
782    LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
783
784#ifndef UNIT_BETA
785    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
786#endif // UNIT_BIAS
787
788    // c = c + bias
789    ADD_BLOCK(M0, c, bias);
790
791#endif // defined(BROADCAST_BIAS)
792#endif // defined(BETA)
793
794#if defined(ACTIVATION_TYPE)
795    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, N0, c, A_VAL, B_VAL);
796#endif // defined(ACTIVATION_TYPE)
797
798    // Store output block
799    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
800
801#undef RHS_BLOCK_SIZE
802#undef RHS_OFFSET_X
803#undef RHS_STEP_X
804#undef RHS_STEP_LOOP
805#undef PIXEL_UNIT
806}
807#endif // defined(OPENCL_IMAGE_SUPPORT) && defined(GEMM_MM_RESHAPED_ONLY_RHS_T_TEXTURE)
808
809#define VFMA(a, b, c)     \
810    ({                    \
811        c = fma(a, b, c); \
812    })
813
814#if M0 == 1
815#define VFMA_M0xN0(i, a, b, c)                                        \
816    ({                                                                \
817        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
818    })
819#elif M0 == 2 // M0 == 2
820#define VFMA_M0xN0(i, a, b, c)                                        \
821    ({                                                                \
822        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
823        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
824    })
825#elif M0 == 3 // M0 == 3
826#define VFMA_M0xN0(i, a, b, c)                                        \
827    ({                                                                \
828        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
829        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
830        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
831    })
832#elif M0 == 4 // M0 == 4
833#define VFMA_M0xN0(i, a, b, c)                                        \
834    ({                                                                \
835        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
836        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
837        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
838        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
839    })
840#elif M0 == 5 // M0 == 5
841#define VFMA_M0xN0(i, a, b, c)                                        \
842    ({                                                                \
843        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
844        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
845        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
846        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
847        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
848    })
849#elif M0 == 6 // M0 == 6
850#define VFMA_M0xN0(i, a, b, c)                                        \
851    ({                                                                \
852        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
853        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
854        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
855        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
856        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
857        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
858    })
859#elif M0 == 7 // M0 == 7
860#define VFMA_M0xN0(i, a, b, c)                                        \
861    ({                                                                \
862        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
863        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
864        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
865        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
866        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
867        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
868        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
869    })
870#elif M0 == 8 // M0 == 8
871#define VFMA_M0xN0(i, a, b, c)                                        \
872    ({                                                                \
873        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
874        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
875        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
876        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
877        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
878        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
879        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
880        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
881    })
882#else // M0 not supported
883#error "M0 not supported"
884#endif // M0 not supported
885
886#if defined(GEMM_MM_RESHAPED_ONLY_RHS_NT)
887/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
888 *  The LHS matrix is NOT reshaped
889 *  The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is NOT transposed
890 * @note This kernel is duplicated in /experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped_only_rhs.cl
891 *
892 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
893 * @note The GEMM's dimensions (M,N and K) must be passed at runtime as kernel parameters.
894 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
895 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
896 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
897 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
898 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
899 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
900 * @note Only the following configurations of M0, N0 and K0 are currently supported:
901 *  - M0 = 1, 2, 3, 4, 5, 6, 7, 8
902 *  - N0 = 2, 3, 4, 8, 16
903 *  - K0 = 2, 3, 4, 8, 16
904 *  - H0 >= 1
905 *
906 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
907 *       The activation function is performed after the bias addition
908 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
909 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
910 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
911 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
912 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
913 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
914 *
915 * @param[in]  lhs_ptr                            Pointer to the LHS matrix. Supported data type: F16/F32
916 * @param[in]  lhs_stride_x                       Stride of the LHS matrix in X dimension (in bytes)
917 * @param[in]  lhs_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
918 * @param[in]  lhs_stride_y                       Stride of the LHS matrix in Y dimension (in bytes)
919 * @param[in]  lhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
920 * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS matrix
921 * @param[in]  rhs_ptr                            Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
922 * @param[in]  rhs_stride_x                       Stride of the RHS reshaped matrix in X dimension (in bytes)
923 * @param[in]  rhs_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
924 * @param[in]  rhs_stride_y                       Stride of the RHS reshaped matrix in Y dimension (in bytes)
925 * @param[in]  rhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
926 * @param[in]  rhs_offset_first_element_in_bytes  The offset of the first element in the RHS reshaped matrix
927 * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
928 * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
929 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
930 * @param[in]  bias_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
931 * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
932 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
933 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data type: same as @p lhs_ptr
934 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
935 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
936 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
937 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
938 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
939 * @param[in]  lhs_stride_z                       Stride of the LHS matrix in Z dimension (in bytes)
940 * @param[in]  rhs_stride_z                       Stride of the RHS reshaped matrix in Z dimension (in bytes)
941 * @param[in]  bias_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
942 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
943 * @param[in]  lhs_cross_plane_pad                (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
944 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
945 * @param[in]  M                                  Number of rows in LHS matrix not reshaped.
946 * @param[in]  N                                  Number of columns in RHS matrix not reshaped.
947 * @param[in]  K                                  Number of columns in LHS matrix and rows in RHS matrix not reshaped.
948 */
949__kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
950                                           IMAGE_DECLARATION(rhs),
951#if defined(BETA)
952                                           IMAGE_DECLARATION(bias),
953#endif // defined(BETA)
954                                           IMAGE_DECLARATION(dst),
955                                           uint lhs_stride_z,
956                                           uint rhs_stride_z,
957#if defined(BETA)
958                                           uint bias_stride_z,
959#endif //defined(BETA)
960                                           uint dst_stride_z
961#if defined(REINTERPRET_INPUT_AS_3D)
962                                           ,
963                                           uint lhs_cross_plane_pad
964#endif // REINTERPRET_INPUT_AS_3D
965#if defined(REINTERPRET_OUTPUT_AS_3D)
966                                           ,
967                                           uint dst_cross_plane_pad
968#endif // REINTERPRET_OUTPUT_AS_3D
969                                           ,
970                                           const int M,
971                                           const int N,
972                                           const int K)
973{
974    // Block size
975#define RHS_BLOCK_SIZE ((K0) * (N0))
976
977    // RHS offset and step X
978#if defined(RHS_INTERLEAVE)
979#define RHS_OFFSET_X (N0)
980#define RHS_STEP_X ((N0) * (H0))
981#define RHS_STEP_LOOP (1)
982#else // defined(RHS_INTERLEAVE)
983#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
984#define RHS_STEP_X (N0)
985#define RHS_STEP_LOOP (H0)
986#endif // defined(RHS_INTERLEAVE)
987
988    uint x = get_global_id(0);
989    uint y = get_global_id(1);
990    uint z = get_global_id(2);
991
992    const bool cond_y = y == 0;
993    const bool cond_x = ((x + 1) * N0 >= N);
994
995#if defined(DUMMY_WORK_ITEMS)
996    if((x * N0 >= N) || (y * M0 >= M))
997    {
998        return;
999    }
1000#endif // defined(DUMMY_WORK_ITEMS)
1001
1002    // Compute LHS matrix address
1003    uint lhs_offset = lhs_offset_first_element_in_bytes + COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * (uint)lhs_stride_y;
1004
1005    // Compute RHS reshaped matrix address
1006    uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1007
1008#if defined(MATRIX_B_DEPTH)
1009    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1010    rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1011#else  // defined(MATRIX_B_DEPTH)
1012    rhs_offset += z * rhs_stride_z;
1013#endif // defined(MATRIX_B_DEPTH)
1014
1015    REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0);   //uint zin0=0,zin1=0,zin2=0,... zin7=0;
1016    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0); //uint zero0=0,zero1=0,zero2=0,... zero7=0;
1017
1018#if defined(REINTERPRET_INPUT_AS_3D)
1019
1020    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
1021    CALCULATE_Z_OFFSET(M0, uint, zin, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
1022
1023    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1024    // multiply lhs_stride_z by DEPTH_GEMM3D
1025    lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1026
1027#else // defined(REINTERPRET_INPUT_AS_3D)
1028
1029    // Add offset for batched GEMM
1030    lhs_offset += z * lhs_stride_z;
1031
1032#endif // defined(REINTERPRET_INPUT_AS_3D)
1033
1034    // Initialize the accumulators
1035    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0)    c0=0,c1=0,c2=0,... c(N0-1)=0;
1036
1037    int i = 0;
1038    for(; i <= (K - K0); i += K0)
1039    {
1040        // Supported cases (M0, K0):
1041        // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1042        // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1043        // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1044        // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1045        // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1046        // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1047        // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1048        // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
1049        // Load values from LHS matrix
1050        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zin);
1051
1052        VEC_DATA_TYPE(DATA_TYPE, N0)
1053        b0;
1054
1055        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * RHS_STEP_X * sizeof(DATA_TYPE)));
1056        VFMA_M0xN0(0, a, b0, c);
1057        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 1 * RHS_STEP_X * sizeof(DATA_TYPE)));
1058        VFMA_M0xN0(1, a, b0, c);
1059#if K0 > 2
1060        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 2 * RHS_STEP_X * sizeof(DATA_TYPE)));
1061        VFMA_M0xN0(2, a, b0, c);
1062#endif // K0 > 2
1063#if K0 > 3
1064        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 3 * RHS_STEP_X * sizeof(DATA_TYPE)));
1065        VFMA_M0xN0(3, a, b0, c);
1066#endif // K0 > 3
1067#if K0 > 4
1068        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 4 * RHS_STEP_X * sizeof(DATA_TYPE)));
1069        VFMA_M0xN0(4, a, b0, c);
1070        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 5 * RHS_STEP_X * sizeof(DATA_TYPE)));
1071        VFMA_M0xN0(5, a, b0, c);
1072        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 6 * RHS_STEP_X * sizeof(DATA_TYPE)));
1073        VFMA_M0xN0(6, a, b0, c);
1074        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 7 * RHS_STEP_X * sizeof(DATA_TYPE)));
1075        VFMA_M0xN0(7, a, b0, c);
1076#endif // K0 > 4
1077#if K0 > 8
1078        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 8 * RHS_STEP_X * sizeof(DATA_TYPE)));
1079        VFMA_M0xN0(8, a, b0, c);
1080        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 9 * RHS_STEP_X * sizeof(DATA_TYPE)));
1081        VFMA_M0xN0(9, a, b0, c);
1082        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 10 * RHS_STEP_X * sizeof(DATA_TYPE)));
1083        VFMA_M0xN0(A, a, b0, c);
1084        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 11 * RHS_STEP_X * sizeof(DATA_TYPE)));
1085        VFMA_M0xN0(B, a, b0, c);
1086        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 12 * RHS_STEP_X * sizeof(DATA_TYPE)));
1087        VFMA_M0xN0(C, a, b0, c);
1088        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 13 * RHS_STEP_X * sizeof(DATA_TYPE)));
1089        VFMA_M0xN0(D, a, b0, c);
1090        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 14 * RHS_STEP_X * sizeof(DATA_TYPE)));
1091        VFMA_M0xN0(E, a, b0, c);
1092        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 15 * RHS_STEP_X * sizeof(DATA_TYPE)));
1093        VFMA_M0xN0(F, a, b0, c);
1094#endif // K0 > 8
1095
1096        lhs_offset += K0 * sizeof(DATA_TYPE);
1097        rhs_offset += K0 * RHS_STEP_X * RHS_STEP_LOOP * sizeof(DATA_TYPE);
1098    }
1099
1100    // Left-over accumulations
1101    for(; i < K; ++i)
1102    {
1103        // Load values from LHS matrix
1104        VEC_DATA_TYPE(DATA_TYPE, 2)
1105        a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
1106#if M0 > 1
1107        VEC_DATA_TYPE(DATA_TYPE, 2)
1108        a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
1109#endif // M0 > 1
1110#if M0 > 2
1111        VEC_DATA_TYPE(DATA_TYPE, 2)
1112        a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
1113#endif // M0 > 2
1114#if M0 > 3
1115        VEC_DATA_TYPE(DATA_TYPE, 2)
1116        a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
1117#endif // M0 > 3
1118#if M0 > 4
1119        VEC_DATA_TYPE(DATA_TYPE, 2)
1120        a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
1121#endif // M0 > 4
1122#if M0 > 5
1123        VEC_DATA_TYPE(DATA_TYPE, 2)
1124        a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
1125#endif // M0 > 5
1126#if M0 > 6
1127        VEC_DATA_TYPE(DATA_TYPE, 2)
1128        a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
1129#endif // M0 > 6
1130#if M0 > 7
1131        VEC_DATA_TYPE(DATA_TYPE, 2)
1132        a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
1133#endif // M0 > 7
1134
1135        VEC_DATA_TYPE(DATA_TYPE, N0)
1136        b0;
1137
1138        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * RHS_STEP_X * sizeof(DATA_TYPE)));
1139        VFMA_M0xN0(0, a, b0, c);
1140
1141        lhs_offset += sizeof(DATA_TYPE);
1142        rhs_offset += RHS_STEP_X * sizeof(DATA_TYPE);
1143    }
1144
1145    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * dst_stride_y);
1146
1147    REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1148
1149#if defined(REINTERPRET_OUTPUT_AS_3D)
1150    // The plane (zout) is calculated by dividing M (y * M0) by HEIGHT_GEMM3D
1151    CALCULATE_Z_OFFSET(M0, uint, zout, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
1152
1153    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1154    // multiply dst_stride_z by DEPTH_GEMM3D
1155    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1156
1157#else // defined(REINTERPRET_OUTPUT_AS_3D)
1158
1159    // Add offset for batched GEMM
1160    dst_addr += z * dst_stride_z;
1161
1162#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1163
1164    // Multiply by the weight of matrix-matrix product and store the result
1165#if defined(ALPHA)
1166    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
1167#endif // defined(ALPHA)
1168
1169    // Add beta*bias
1170#if defined(BETA)
1171#if defined(BROADCAST_BIAS)
1172    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1173
1174    LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, 1, PARTIAL_STORE_N0, false, cond_x);
1175
1176#ifndef UNIT_BETA
1177    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
1178#endif // UNIT_BETA
1179
1180    // c = c + bias[broadcasted]
1181    ADD_BLOCK_BROADCAST(M0, c, bias0);
1182
1183#else // defined(BROADCAST_BIAS)
1184    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * bias_stride_y) + z * bias_stride_z;
1185
1186    LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
1187
1188#ifndef UNIT_BETA
1189    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
1190#endif // UNIT_BETA
1191
1192    // c = c + bias
1193    ADD_BLOCK(M0, c, bias);
1194
1195#endif // defined(BROADCAST_BIAS)
1196#endif // defined(BETA)
1197
1198#if defined(ACTIVATION_TYPE)
1199    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, N0, c, A_VAL, B_VAL);
1200#endif // defined(ACTIVATION_TYPE)
1201
1202    // Store output block
1203    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
1204
1205#undef RHS_BLOCK_SIZE
1206#undef RHS_OFFSET_X
1207#undef RHS_STEP_X
1208#undef RHS_STEP_LOOP
1209}
1210#endif // defined(GEMM_MM_RESHAPED_ONLY_RHS_NT)
1211
1212#if defined(OPENCL_IMAGE_SUPPORT) && defined(GEMM_MM_RESHAPED_ONLY_RHS_NT_TEXTURE)
1213/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1214 *  The LHS matrix is NOT reshaped
1215 *  The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is NOT transposed
1216 * @note This kernel is duplicated in /experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped_only_rhs.cl
1217 *
1218 * @note -DOPENCL_IMAGE_SUPPORT must be passed at compile time in order to compile this OpenCL kernel
1219 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
1220 * @note The GEMM's dimensions (M,N and K) must be passed at runtime as kernel parameters.
1221 * @note The height of the RHS matrix, defined before creating the OpenCL image object from the OpenCL buffer, should be passed at compile time using -DRHS_HEIGHT=<value> (e.g. -DRHS_HEIGHT=32)
1222 *       Since we cannot create a 3d image from a buffer, the third dimension could be collapsed with the second dimension so RHS_HEIGHT
1223 *       could be different from the value returned by get_image_height(rhs_img).
1224 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
1225 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
1226 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
1227 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
1228 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
1229 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
1230 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1231 *  - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1232 *  - N0 = 4, 8, 16
1233 *  - K0 = 4, 8, 16
1234 *  - H0 >= 1
1235 *
1236 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B values required by some activation functions should also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
1237 *       The activation function is performed after the bias addition
1238 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1239 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1240 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1241 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1242 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1243 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
1244 *
1245 * @param[in]  lhs_ptr                            Pointer to the LHS matrix. Supported data type: F32
1246 * @param[in]  lhs_stride_x                       Stride of the LHS matrix in X dimension (in bytes)
1247 * @param[in]  lhs_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
1248 * @param[in]  lhs_stride_y                       Stride of the LHS matrix in Y dimension (in bytes)
1249 * @param[in]  lhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
1250 * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS matrix
1251 * @param[in]  rhs_img                            The RHS reshaped matrix as OpenCL image object. Supported data type: same as @p lhs_ptr
1252 * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1253 * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
1254 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
1255 * @param[in]  bias_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
1256 * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1257 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1258 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data type: same as @p lhs_ptr
1259 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
1260 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
1261 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
1262 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
1263 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
1264 * @param[in]  lhs_stride_z                       Stride of the LHS matrix in Z dimension (in bytes)
1265 * @param[in]  rhs_stride_z                       Stride of the RHS reshaped matrix in Z dimension (in bytes)
1266 * @param[in]  bias_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
1267 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
1268 * @param[in]  lhs_cross_plane_pad                (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
1269 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
1270 * @param[in]  M                                  Number of rows in LHS matrix not reshaped.
1271 * @param[in]  N                                  Number of columns in RHS matrix not reshaped.
1272 * @param[in]  K                                  Number of columns in LHS matrix and rows in RHS matrix not reshaped.
1273 */
1274__kernel void gemm_mm_reshaped_only_rhs_nt_texture(IMAGE_DECLARATION(lhs),
1275                                                   __read_only image2d_t rhs_img,
1276#if defined(BETA)
1277                                                   IMAGE_DECLARATION(bias),
1278#endif // defined(BETA)
1279                                                   IMAGE_DECLARATION(dst),
1280                                                   uint lhs_stride_z,
1281                                                   uint rhs_stride_z,
1282#if defined(BETA)
1283                                                   uint bias_stride_z,
1284#endif //defined(BETA)
1285                                                   uint dst_stride_z
1286#if defined(REINTERPRET_INPUT_AS_3D)
1287                                                   ,
1288                                                   uint lhs_cross_plane_pad
1289#endif // REINTERPRET_INPUT_AS_3D
1290#if defined(REINTERPRET_OUTPUT_AS_3D)
1291                                                   ,
1292                                                   uint dst_cross_plane_pad
1293#endif // REINTERPRET_OUTPUT_AS_3D
1294                                                   ,
1295                                                   const int M,
1296                                                   const int N,
1297                                                   const int K)
1298{
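    // Illustrative example (hypothetical values, consistent with the constraints documented above):
    // this kernel could, for instance, be compiled with
    //   -DDATA_TYPE=float -DM0=4 -DN0=4 -DK0=4 -DH0=4 -DRHS_HEIGHT=256
    //   -DPARTIAL_STORE_M0=1 -DPARTIAL_STORE_N0=1 -DOPENCL_IMAGE_SUPPORT -DGEMM_MM_RESHAPED_ONLY_RHS_NT_TEXTURE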
1299    // Pixel unit
1300#define PIXEL_UNIT CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(N0)
1301
1302    // Block size
1303#define RHS_BLOCK_SIZE ((K0) * (PIXEL_UNIT))
1304
1305    // RHS offset and step X
1306#if defined(RHS_INTERLEAVE)
1307#define RHS_OFFSET_X (PIXEL_UNIT)
1308#define RHS_STEP_X ((PIXEL_UNIT) * (H0))
1309#define RHS_STEP_LOOP 1
1310#else // defined(RHS_INTERLEAVE)
1311#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1312#define RHS_STEP_X (PIXEL_UNIT)
1313#define RHS_STEP_LOOP (H0)
1314#endif // defined(RHS_INTERLEAVE)
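    // Illustrative example (assuming one RGBA image pixel holds 4 F32 elements, i.e. PIXEL_UNIT = N0 / 4):
    // with N0 = 4, K0 = 4, H0 = 2 and no RHS_INTERLEAVE, PIXEL_UNIT = 1, RHS_BLOCK_SIZE = 4,
    // RHS_OFFSET_X = 4, RHS_STEP_X = 1 and RHS_STEP_LOOP = 2.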
1315
1316    uint x = get_global_id(0);
1317    uint y = get_global_id(1);
1318    uint z = get_global_id(2);
1319
1320    const bool cond_y = y == 0;
1321    const bool cond_x = ((x + 1) * N0 >= N);
1322
1323#if defined(DUMMY_WORK_ITEMS)
1324    if((x * N0 >= N) || (y * M0 >= M))
1325    {
1326        return;
1327    }
1328#endif // defined(DUMMY_WORK_ITEMS)
1329
1330    // Compute LHS matrix address
1331    uint lhs_offset = lhs_offset_first_element_in_bytes + COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * (uint)lhs_stride_y;
1332
1333#if defined(MATRIX_B_DEPTH)
1334    // Do not slide matrix B if matrix B has 3 dimensions and matrix A has more than 3
1335    const uint z_rhs = (z % MATRIX_B_DEPTH);
1336#else  // defined(MATRIX_B_DEPTH)
1337    const uint z_rhs = z;
1338#endif // defined(MATRIX_B_DEPTH)
1339
1340    // Compute RHS matrix coordinates
1341    uint       x_rhs = (x % H0) * (uint)RHS_OFFSET_X;
1342    const uint y_rhs = (x / (uint)H0) + z_rhs * RHS_HEIGHT;
1343
1344    REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0);
1345    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
1346
1347#if defined(REINTERPRET_INPUT_AS_3D)
1348
1349    // The plane (zin) is calculated by dividing M (y * M0) by HEIGHT_GEMM3D
1350    CALCULATE_Z_OFFSET(M0, uint, zin, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
1351
1352    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1353    // multiply lhs_stride_z by DEPTH_GEMM3D
1354    lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1355
1356#else // defined(REINTERPRET_INPUT_AS_3D)
1357
1358    // Add offset for batched GEMM
1359    lhs_offset += z * lhs_stride_z;
1360
1361#endif // defined(REINTERPRET_INPUT_AS_3D)
1362
1363    // Initialize the accumulators
1364    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0);
1365
1366    int i = 0;
1367    for(; i <= (K - K0); i += K0)
1368    {
1369        // Load values from LHS matrix
1370        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zin);
1371
1372        VEC_DATA_TYPE(DATA_TYPE, N0)
1373        b0;
1374
1375        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 0 * RHS_STEP_X), (y_rhs));
1376        VFMA_M0xN0(0, a, b0, c);
1377        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 1 * RHS_STEP_X), (y_rhs));
1378        VFMA_M0xN0(1, a, b0, c);
1379#if K0 > 2
1380        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 2 * RHS_STEP_X), (y_rhs));
1381        VFMA_M0xN0(2, a, b0, c);
1382#endif // K0 > 2
1383#if K0 > 3
1384        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 3 * RHS_STEP_X), (y_rhs));
1385        VFMA_M0xN0(3, a, b0, c);
1386#endif // K0 > 3
1387#if K0 > 4
1388        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 4 * RHS_STEP_X), (y_rhs));
1389        VFMA_M0xN0(4, a, b0, c);
1390        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 5 * RHS_STEP_X), (y_rhs));
1391        VFMA_M0xN0(5, a, b0, c);
1392        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 6 * RHS_STEP_X), (y_rhs));
1393        VFMA_M0xN0(6, a, b0, c);
1394        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 7 * RHS_STEP_X), (y_rhs));
1395        VFMA_M0xN0(7, a, b0, c);
1396#endif // K0 > 4
1397#if K0 > 8
1398        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 8 * RHS_STEP_X), (y_rhs));
1399        VFMA_M0xN0(8, a, b0, c);
1400        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 9 * RHS_STEP_X), (y_rhs));
1401        VFMA_M0xN0(9, a, b0, c);
1402        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 10 * RHS_STEP_X), (y_rhs));
1403        VFMA_M0xN0(A, a, b0, c);
1404        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 11 * RHS_STEP_X), (y_rhs));
1405        VFMA_M0xN0(B, a, b0, c);
1406        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 12 * RHS_STEP_X), (y_rhs));
1407        VFMA_M0xN0(C, a, b0, c);
1408        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 13 * RHS_STEP_X), (y_rhs));
1409        VFMA_M0xN0(D, a, b0, c);
1410        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 14 * RHS_STEP_X), (y_rhs));
1411        VFMA_M0xN0(E, a, b0, c);
1412        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 15 * RHS_STEP_X), (y_rhs));
1413        VFMA_M0xN0(F, a, b0, c);
1414#endif // K0 > 8
1415
1416        lhs_offset += K0 * sizeof(DATA_TYPE);
1417        x_rhs += K0 * RHS_STEP_X * RHS_STEP_LOOP;
1418    }
1419
1420    // Left-over accumulations
1421    for(; i < K; ++i)
1422    {
1423        // Load values from LHS matrix
1424        VEC_DATA_TYPE(DATA_TYPE, 2)
1425        a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
1426#if M0 > 1
1427        VEC_DATA_TYPE(DATA_TYPE, 2)
1428        a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
1429#endif // M0 > 1
1430#if M0 > 2
1431        VEC_DATA_TYPE(DATA_TYPE, 2)
1432        a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
1433#endif // M0 > 2
1434#if M0 > 3
1435        VEC_DATA_TYPE(DATA_TYPE, 2)
1436        a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
1437#endif // M0 > 3
1438#if M0 > 4
1439        VEC_DATA_TYPE(DATA_TYPE, 2)
1440        a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
1441#endif // M0 > 4
1442#if M0 > 5
1443        VEC_DATA_TYPE(DATA_TYPE, 2)
1444        a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
1445#endif // M0 > 5
1446#if M0 > 6
1447        VEC_DATA_TYPE(DATA_TYPE, 2)
1448        a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
1449#endif // M0 > 6
1450#if M0 > 7
1451        VEC_DATA_TYPE(DATA_TYPE, 2)
1452        a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
1453#endif // M0 > 7
1454
1455        VEC_DATA_TYPE(DATA_TYPE, N0)
1456        b0;
1457        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 0 * RHS_STEP_X), (y_rhs));
1458
1459        VFMA_M0xN0(0, a, b0, c);
1460
1461        lhs_offset += sizeof(DATA_TYPE);
1462        x_rhs += RHS_STEP_X;
1463    }
1464
1465    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * dst_stride_y);
1466
1467    REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1468
1469#if defined(REINTERPRET_OUTPUT_AS_3D)
1470    // The plane (zout) is calculated by dividing M (y * M0) by HEIGHT_GEMM3D
1471    CALCULATE_Z_OFFSET(M0, uint, zout, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
1472
1473    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1474    // multiply dst_stride_z by DEPTH_GEMM3D
1475    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1476
1477#else // defined(REINTERPRET_OUTPUT_AS_3D)
1478
1479    // Add offset for batched GEMM
1480    dst_addr += z * dst_stride_z;
1481
1482#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1483
1484    // Multiply by the weight of matrix-matrix product and store the result
1485#if defined(ALPHA)
1486    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
1487#endif // defined(ALPHA)
1488
1489    // Add beta*bias
1490#if defined(BETA)
1491#if defined(BROADCAST_BIAS)
1492    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1493
1494    LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, 1, PARTIAL_STORE_N0, false, cond_x);
1495
1496#ifndef UNIT_BETA
1497    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
1498#endif // UNIT_BETA
1499
1500    // c = c + bias[broadcasted]
1501    ADD_BLOCK_BROADCAST(M0, c, bias0);
1502
1503#else // defined(BROADCAST_BIAS)
1504    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * bias_stride_y) + z * bias_stride_z;
1505
1506    LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
1507
1508#ifndef UNIT_BETA
1509    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
1510#endif // UNIT_BETA
1511
1512    // c = c + bias
1513    ADD_BLOCK(M0, c, bias);
1514
1515#endif // defined(BROADCAST_BIAS)
1516#endif // defined(BETA)
1517
1518#if defined(ACTIVATION_TYPE)
1519    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, N0, c, A_VAL, B_VAL);
1520#endif // defined(ACTIVATION_TYPE)
1521
1522    // Store output block
1523    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
1524
1525#undef RHS_BLOCK_SIZE
1526#undef RHS_OFFSET_X
1527#undef RHS_STEP_X
1528#undef RHS_STEP_LOOP
1529}
1530#endif // defined(OPENCL_IMAGE_SUPPORT) && defined(GEMM_MM_RESHAPED_ONLY_RHS_NT_TEXTURE)
1531#endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE)
1532
1533#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(DATA_TYPE_ACCUMULATOR)
1534
1535#if defined(MIXED_PRECISION)
1536#if K0 == 2
1537#define ARM_DOT_K0(a, b, c) \
1538    ({                      \
1539        c += a.s0 * b.s0;   \
1540        c += a.s1 * b.s1;   \
1541    })
1542#elif K0 == 3 // K0 == 3
1543#define ARM_DOT_K0(a, b, c) \
1544    ({                      \
1545        c += a.s0 * b.s0;   \
1546        c += a.s1 * b.s1;   \
1547        c += a.s2 * b.s2;   \
1548    })
1549#elif K0 == 4 // K0 == 4
1550#define ARM_DOT_K0(a, b, c) \
1551    ({                      \
1552        c += a.s0 * b.s0;   \
1553        c += a.s1 * b.s1;   \
1554        c += a.s2 * b.s2;   \
1555        c += a.s3 * b.s3;   \
1556    })
1557#elif K0 == 8 // K0 == 8
1558#define ARM_DOT_K0(a, b, c) \
1559    ({                      \
1560        c += a.s0 * b.s0;   \
1561        c += a.s1 * b.s1;   \
1562        c += a.s2 * b.s2;   \
1563        c += a.s3 * b.s3;   \
1564        c += a.s4 * b.s4;   \
1565        c += a.s5 * b.s5;   \
1566        c += a.s6 * b.s6;   \
1567        c += a.s7 * b.s7;   \
1568    })
1569#elif K0 == 16 // K0 == 16
1570#define ARM_DOT_K0(a, b, c) \
1571    ({                      \
1572        c += a.s0 * b.s0;   \
1573        c += a.s1 * b.s1;   \
1574        c += a.s2 * b.s2;   \
1575        c += a.s3 * b.s3;   \
1576        c += a.s4 * b.s4;   \
1577        c += a.s5 * b.s5;   \
1578        c += a.s6 * b.s6;   \
1579        c += a.s7 * b.s7;   \
1580        c += a.s8 * b.s8;   \
1581        c += a.s9 * b.s9;   \
1582        c += a.sA * b.sA;   \
1583        c += a.sB * b.sB;   \
1584        c += a.sC * b.sC;   \
1585        c += a.sD * b.sD;   \
1586        c += a.sE * b.sE;   \
1587        c += a.sF * b.sF;   \
1588    })
1589#else // K0 not supported
1590#error "K0 value not supported"
1591#endif // K0 conditions
1592#else  // defined(MIXED_PRECISION)
1593#if K0 == 2
1594#define ARM_DOT_K0(a, b, c)     \
1595    ({                          \
1596        c = fma(a.s0, b.s0, c); \
1597        c = fma(a.s1, b.s1, c); \
1598    })
1599#elif K0 == 3 // K0 == 3
1600#define ARM_DOT_K0(a, b, c)     \
1601    ({                          \
1602        c = fma(a.s0, b.s0, c); \
1603        c = fma(a.s1, b.s1, c); \
1604        c = fma(a.s2, b.s2, c); \
1605    })
1606#elif K0 == 4 // K0 == 4
1607#define ARM_DOT_K0(a, b, c)     \
1608    ({                          \
1609        c = fma(a.s0, b.s0, c); \
1610        c = fma(a.s1, b.s1, c); \
1611        c = fma(a.s2, b.s2, c); \
1612        c = fma(a.s3, b.s3, c); \
1613    })
1614#elif K0 == 8 // K0 == 8
1615#define ARM_DOT_K0(a, b, c)     \
1616    ({                          \
1617        c = fma(a.s0, b.s0, c); \
1618        c = fma(a.s1, b.s1, c); \
1619        c = fma(a.s2, b.s2, c); \
1620        c = fma(a.s3, b.s3, c); \
1621        c = fma(a.s4, b.s4, c); \
1622        c = fma(a.s5, b.s5, c); \
1623        c = fma(a.s6, b.s6, c); \
1624        c = fma(a.s7, b.s7, c); \
1625    })
1626#elif K0 == 16 // K0 == 16
1627#define ARM_DOT_K0(a, b, c)     \
1628    ({                          \
1629        c = fma(a.s0, b.s0, c); \
1630        c = fma(a.s1, b.s1, c); \
1631        c = fma(a.s2, b.s2, c); \
1632        c = fma(a.s3, b.s3, c); \
1633        c = fma(a.s4, b.s4, c); \
1634        c = fma(a.s5, b.s5, c); \
1635        c = fma(a.s6, b.s6, c); \
1636        c = fma(a.s7, b.s7, c); \
1637        c = fma(a.s8, b.s8, c); \
1638        c = fma(a.s9, b.s9, c); \
1639        c = fma(a.sA, b.sA, c); \
1640        c = fma(a.sB, b.sB, c); \
1641        c = fma(a.sC, b.sC, c); \
1642        c = fma(a.sD, b.sD, c); \
1643        c = fma(a.sE, b.sE, c); \
1644        c = fma(a.sF, b.sF, c); \
1645    })
1646#else // K0 not supported
1647#error "K0 value not supported"
1648#endif // K0 conditions
1649#endif // defined(MIXED_PRECISION)
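// Illustrative expansion (K0 = 2): with -DMIXED_PRECISION, ARM_DOT_K0(a, b, c) becomes
//   c += a.s0 * b.s0; c += a.s1 * b.s1;
// so the products are accumulated in the type of c (DATA_TYPE_ACCUMULATOR in the kernels below),
// whereas the default variant keeps everything in DATA_TYPE and uses fused multiply-adds:
//   c = fma(a.s0, b.s0, c); c = fma(a.s1, b.s1, c);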
1650
1651#if defined(ARM_DOT_K0XN0)
1652#undef ARM_DOT_K0XN0
1653#endif // defined(ARM_DOT_K0XN0)
1654
1655#if N0 == 2
1656#define ARM_DOT_K0XN0(a, b, c)           \
1657    ({                                   \
1658        ARM_DOT_K0((a), (b##0), (c.s0)); \
1659        ARM_DOT_K0((a), (b##1), (c.s1)); \
1660    })
1661#elif N0 == 3 // N0 == 3
1662#define ARM_DOT_K0XN0(a, b, c)           \
1663    ({                                   \
1664        ARM_DOT_K0((a), (b##0), (c.s0)); \
1665        ARM_DOT_K0((a), (b##1), (c.s1)); \
1666        ARM_DOT_K0((a), (b##2), (c.s2)); \
1667    })
1668#elif N0 == 4 // N0 == 4
1669#define ARM_DOT_K0XN0(a, b, c)           \
1670    ({                                   \
1671        ARM_DOT_K0((a), (b##0), (c.s0)); \
1672        ARM_DOT_K0((a), (b##1), (c.s1)); \
1673        ARM_DOT_K0((a), (b##2), (c.s2)); \
1674        ARM_DOT_K0((a), (b##3), (c.s3)); \
1675    })
1676#elif N0 == 8 // N0 == 8
1677#define ARM_DOT_K0XN0(a, b, c)           \
1678    ({                                   \
1679        ARM_DOT_K0((a), (b##0), (c.s0)); \
1680        ARM_DOT_K0((a), (b##1), (c.s1)); \
1681        ARM_DOT_K0((a), (b##2), (c.s2)); \
1682        ARM_DOT_K0((a), (b##3), (c.s3)); \
1683        ARM_DOT_K0((a), (b##4), (c.s4)); \
1684        ARM_DOT_K0((a), (b##5), (c.s5)); \
1685        ARM_DOT_K0((a), (b##6), (c.s6)); \
1686        ARM_DOT_K0((a), (b##7), (c.s7)); \
1687    })
1688#elif N0 == 16 // N0 == 16
1689#define ARM_DOT_K0XN0(a, b, c)           \
1690    ({                                   \
1691        ARM_DOT_K0((a), (b##0), (c.s0)); \
1692        ARM_DOT_K0((a), (b##1), (c.s1)); \
1693        ARM_DOT_K0((a), (b##2), (c.s2)); \
1694        ARM_DOT_K0((a), (b##3), (c.s3)); \
1695        ARM_DOT_K0((a), (b##4), (c.s4)); \
1696        ARM_DOT_K0((a), (b##5), (c.s5)); \
1697        ARM_DOT_K0((a), (b##6), (c.s6)); \
1698        ARM_DOT_K0((a), (b##7), (c.s7)); \
1699        ARM_DOT_K0((a), (b##8), (c.s8)); \
1700        ARM_DOT_K0((a), (b##9), (c.s9)); \
1701        ARM_DOT_K0((a), (b##A), (c.sA)); \
1702        ARM_DOT_K0((a), (b##B), (c.sB)); \
1703        ARM_DOT_K0((a), (b##C), (c.sC)); \
1704        ARM_DOT_K0((a), (b##D), (c.sD)); \
1705        ARM_DOT_K0((a), (b##E), (c.sE)); \
1706        ARM_DOT_K0((a), (b##F), (c.sF)); \
1707    })
1708#else // N0 not supported
1709#error "N0 value not supported"
1710#endif // N0 conditions
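// Illustrative expansion (N0 = 2): ARM_DOT_K0XN0(a0, b, c0) becomes
//   ARM_DOT_K0((a0), (b0), (c0.s0)); ARM_DOT_K0((a0), (b1), (c0.s1));
// i.e. one K0-long dot product per accumulator column, reusing the same LHS row a0 for every RHS column.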
1711
1712#if defined(GEMM_MM_RESHAPED_LHS_NT_RHS_T)
1713/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1714 *  The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 block must NOT be transposed
1715 *  The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 block must be transposed
1716 * @note This kernel is duplicated in /experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped.cl
1717 *
1718 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
1719 * @note The data type used for the accumulators must be passed at compile time using -DDATA_TYPE_ACCUMULATOR (e.g. -DDATA_TYPE_ACCUMULATOR=float)
1720 * @note The F16 computation also supports mixed precision through the option -DMIXED_PRECISION passed at compile time. If enabled, DATA_TYPE_ACCUMULATOR should be set to float
1721 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
1722 * @note The GEMM's dimensions M, N and K must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=90 and -DK=24).
1723 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
1724 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. -DV0=2)
1725 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
1726 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must be passed at compile time.
1727 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
1728 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
1729 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
1730 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1731 *  - M0 = 2, 3, 4, 5, 6, 7, 8
1732 *  - N0 = 2, 3, 4, 8, 16
1733 *  - K0 = 2, 3, 4, 8, 16
1734 *  - V0 >= 1
1735 *  - H0 >= 1
1736 *
1737 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B values required by some activation functions should also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
1738 *       The activation function is performed after the bias addition
1739 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
1740 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1741 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1742 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1743 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
1744 *
1745 * @param[in]  lhs_ptr                            Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1746 * @param[in]  lhs_stride_x                       Stride of the LHS reshaped matrix in X dimension (in bytes)
1747 * @param[in]  lhs_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
1748 * @param[in]  lhs_stride_y                       Stride of the LHS reshaped matrix in Y dimension (in bytes)
1749 * @param[in]  lhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
1750 * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS reshaped matrix
1751 * @param[in]  rhs_ptr                            Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1752 * @param[in]  rhs_stride_x                       Stride of the RHS reshaped matrix in X dimension (in bytes)
1753 * @param[in]  rhs_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
1754 * @param[in]  rhs_stride_y                       Stride of the RHS reshaped matrix in Y dimension (in bytes)
1755 * @param[in]  rhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
1756 * @param[in]  rhs_offset_first_element_in_bytes  The offset of the first element in the RHS reshaped matrix
1757 * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1758 * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
1759 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
1760 * @param[in]  bias_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
1761 * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1762 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1763 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data type: same as @p lhs_ptr
1764 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
1765 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
1766 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
1767 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
1768 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
1769 * @param[in]  lhs_stride_z                       Stride of the LHS reshaped matrix in Z dimension (in bytes)
1770 * @param[in]  rhs_stride_z                       Stride of the RHS reshaped matrix in Z dimension (in bytes)
1771 * @param[in]  bias_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
1772 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
1773 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
1774 * @param[in]  M                                  Number of rows in LHS matrix not reshaped.
1775 * @param[in]  N                                  Number of columns in RHS matrix not reshaped.
1776 * @param[in]  K                                  Number of columns in LHS matrix and rows in RHS matrix not reshaped.
1777 */
1778__kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
1779                                            IMAGE_DECLARATION(rhs),
1780#if defined(BETA)
1781                                            IMAGE_DECLARATION(bias),
1782#endif // defined(BETA)
1783                                            IMAGE_DECLARATION(dst),
1784                                            uint lhs_stride_z,
1785                                            uint rhs_stride_z,
1786#if defined(BETA)
1787                                            uint bias_stride_z,
1788#endif //defined(BETA)
1789                                            uint dst_stride_z
1790#if defined(REINTERPRET_OUTPUT_AS_3D)
1791                                            ,
1792                                            uint dst_cross_plane_pad
1793#endif // REINTERPRET_OUTPUT_AS_3D
1794                                            ,
1795                                            const int M,
1796                                            const int N,
1797                                            const int K)
1798{
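    // Illustrative example (hypothetical values, consistent with the constraints documented above):
    // this kernel could, for instance, be compiled with
    //   -DDATA_TYPE=half -DDATA_TYPE_ACCUMULATOR=float -DMIXED_PRECISION -DM0=4 -DN0=8 -DK0=4 -DV0=2 -DH0=2
    //   -DPARTIAL_STORE_M0=1 -DPARTIAL_STORE_N0=1 -DGEMM_MM_RESHAPED_LHS_NT_RHS_T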
1799    // Block size
1800#define LHS_BLOCK_SIZE ((K0) * (M0))
1801
1802#if defined(LHS_INTERLEAVE)
1803#define LHS_OFFSET_X (K0)
1804#define LHS_STEP_X ((K0) * (V0))
1805#define LHS_STEP_LOOP (1)
1806#else // defined(LHS_INTERLEAVE)
1807#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
1808#define LHS_STEP_X (K0)
1809#define LHS_STEP_LOOP (V0)
1810#endif // defined(LHS_INTERLEAVE)
1811
1812    // Block size
1813#define RHS_BLOCK_SIZE ((K0) * (N0))
1814
1815    // RHS offset and step X
1816#if defined(RHS_INTERLEAVE)
1817#define RHS_OFFSET_X (K0)
1818#define RHS_STEP_X ((K0) * (H0))
1819#define RHS_STEP_LOOP (1)
1820#else // defined(RHS_INTERLEAVE)
1821#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1822#define RHS_STEP_X (K0)
1823#define RHS_STEP_LOOP (H0)
1824#endif // defined(RHS_INTERLEAVE)
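    // Illustrative example (hypothetical values): with M0 = 4, N0 = 8, K0 = 4, V0 = 2, H0 = 2 and no
    // interleaving, LHS_BLOCK_SIZE = 16, LHS_OFFSET_X = 16, LHS_STEP_X = 4, LHS_STEP_LOOP = 2,
    // RHS_BLOCK_SIZE = 32, RHS_OFFSET_X = 32, RHS_STEP_X = 4 and RHS_STEP_LOOP = 2.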
1825
1826#if defined(DUMMY_WORK_ITEMS)
1827    if((get_global_id(0) * N0 >= N) || (get_global_id(1) * M0 >= M))
1828    {
1829        return;
1830    }
1831#endif // defined(DUMMY_WORK_ITEMS)
1832
1833    // Compute LHS matrix address
1834    __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
1835                               (get_global_id(2) * lhs_stride_z);
1836
1837    // Compute RHS matrix address
1838    __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (get_global_id(0) % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(0) / (uint)H0) * rhs_stride_y;
1839
1840#if defined(MATRIX_B_DEPTH)
1841    // Do not slide matrix B if matrix B has 3 dimensions and matrix A has more than 3
1842    rhs_addr += (get_global_id(2) % MATRIX_B_DEPTH) * rhs_stride_z;
1843#else  // defined(MATRIX_B_DEPTH)
1844    rhs_addr += get_global_id(2) * rhs_stride_z;
1845#endif // defined(MATRIX_B_DEPTH)
1846
1847    // Initialize the accumulators
1848    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0), c, 0);
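    // Illustrative expansion (M0 = 4): the line above declares one N0-wide accumulator per output row,
    // conceptually VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0) c0 = 0, c1 = 0, c2 = 0, c3 = 0;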
1849
1850    REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
1851    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
1852
1853    for(int i = 0; i < K; i += K0)
1854    {
1855        // Supported cases (M0, K0):
1856        // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1857        // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1858        // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1859        // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1860        // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1861        // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1862        // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1863        // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
1864        // Load values from LHS matrix
1865        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);
1866
1867        // Load values from RHS matrix
1868        LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zero);
1869
1870        // Accumulate
1871        ARM_DOT_K0XN0(a0, b, c0);
1872#if M0 > 1
1873        ARM_DOT_K0XN0(a1, b, c1);
1874#endif // M0 > 1
1875#if M0 > 2
1876        ARM_DOT_K0XN0(a2, b, c2);
1877#endif // M0 > 2
1878#if M0 > 3
1879        ARM_DOT_K0XN0(a3, b, c3);
1880#endif // M0 > 3
1881#if M0 > 4
1882        ARM_DOT_K0XN0(a4, b, c4);
1883#endif // M0 > 4
1884#if M0 > 5
1885        ARM_DOT_K0XN0(a5, b, c5);
1886#endif // M0 > 5
1887#if M0 > 6
1888        ARM_DOT_K0XN0(a6, b, c6);
1889#endif // M0 > 6
1890#if M0 > 7
1891        ARM_DOT_K0XN0(a7, b, c7);
1892#endif // M0 > 7
1893
1894        lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
1895        rhs_addr += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
1896    }
1897
1898    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
1899
1900    REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
1901
1902    const bool cond_y = ((get_global_id(1) + 1) * M0 >= M);
1903    const bool cond_x = ((get_global_id(0) + 1) * N0 >= N);
1904
1905#if defined(REINTERPRET_OUTPUT_AS_3D)
1906
1907    // The plane (zout) is calculated by dividing M (y * M0) by HEIGHT_GEMM3D
1908    CALCULATE_Z_OFFSET(M0, uint, zout, get_global_id(1) * (uint)M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
1909    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1910    // multiply dst_stride_z by DEPTH_GEMM3D
1911    dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
1912
1913#else // defined(REINTERPRET_OUTPUT_AS_3D)
1914
1915    // Add offset for batched GEMM
1916    dst_addr += get_global_id(2) * dst_stride_z;
1917
1918#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1919
1920    // Multiply by the weight of matrix-matrix product and store the result
1921#if defined(ALPHA)
1922    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
1923#endif // defined(ALPHA)
1924
1925    // Add beta*bias
1926#if defined(BETA)
1927#if defined(BROADCAST_BIAS)
1928    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1929
1930    LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, 1, PARTIAL_STORE_N0, false, cond_x);
1931
1932#ifndef UNIT_BETA
1933    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
1934#endif // UNIT_BETA
1935
1936    // c = c + bias[broadcasted]
1937#if defined(MIXED_PRECISION)
1938    CONVERT_BLOCK(1, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
1939    ADD_BLOCK_BROADCAST(M0, c, bias_hp0);
1940#else  // defined(MIXED_PRECISION)
1941    ADD_BLOCK_BROADCAST(M0, c, bias0);
1942#endif // defined(MIXED_PRECISION)
1943
1944#else // defined(BROADCAST_BIAS)
1945    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
1946                                    2) * bias_stride_z;
1947
1948    LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
1949
1950#ifndef UNIT_BETA
1951    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
1952#endif // UNIT_BETA
1953
1954    // c = c + bias
1955#if defined(MIXED_PRECISION)
1956    CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
1957    ADD_BLOCK(M0, c, bias_hp);
1958#else  // defined(MIXED_PRECISION)
1959    ADD_BLOCK(M0, c, bias);
1960#endif // defined(MIXED_PRECISION)
1961
1962#endif // defined(BROADCAST_BIAS)
1963#endif // defined(BETA)
1964
1965#if defined(ACTIVATION_TYPE)
1966#if defined(MIXED_PRECISION)
1967    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE_ACCUMULATOR, N0, c, A_VAL, B_VAL);
1968#else  // defined(MIXED_PRECISION)
1969    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, N0, c, A_VAL, B_VAL);
1970#endif // defined(MIXED_PRECISION)
1971#endif // defined(ACTIVATION_TYPE)
1972
1973    // Store output block
1974#if defined(MIXED_PRECISION)
1975    CONVERT_BLOCK(M0, N0, DATA_TYPE, c, c_lp);
1976    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c_lp, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
1977#else  // defined(MIXED_PRECISION)
1978    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
1979#endif // defined(MIXED_PRECISION)
1980
1981#undef LHS_BLOCK_SIZE
1982#undef LHS_OFFSET_X
1983#undef LHS_STEP_X
1984#undef RHS_BLOCK_SIZE
1985#undef RHS_OFFSET_X
1986#undef RHS_STEP_X
1987#undef LHS_STEP_LOOP
1988#undef RHS_STEP_LOOP
1989}
1990#endif // defined(GEMM_MM_RESHAPED_LHS_NT_RHS_T)
1991
1992#if defined(OPENCL_IMAGE_SUPPORT) && defined(GEMM_MM_RESHAPED_LHS_NT_RHS_T_TEXTURE)
1993/** This OpenCL kernel computes the matrix multiplication between 2 matrices. The RHS matrix is stored in OpenCL image object.
1994 *  The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 block must NOT be transposed
1995 *  The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 block must be transposed
1996 * @note This kernel is duplicated in /experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped.cl
1997 *
1998 * @note -DOPENCL_IMAGE_SUPPORT must be passed at compile time in order to compile this OpenCL kernel
1999 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
2000 * @note The data type used for the accumulators must be passed at compile time using -DDATA_TYPE_ACCUMULATOR (e.g. -DDATA_TYPE_ACCUMULATOR=float)
2001 * @note The F16 computation also supports mixed precision through the option -DMIXED_PRECISION passed at compile time. If enabled, DATA_TYPE_ACCUMULATOR should be set to float
2002 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
2003 * @note The GEMM's dimensions M, N and K must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=90 and -DK=24).
2004 * @note The height of the RHS matrix, defined before creating the OpenCL image object from the OpenCL buffer, should be passed at compile time using -DRHS_HEIGHT=<value> (e.g. -DRHS_HEIGHT=32)
2005 *       Since we cannot create a 3d image from a buffer, the third dimension could be collapsed with the second dimension so RHS_HEIGHT
2006 *       could be different from the value returned by get_image_height(rhs_img).
2007 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
2008 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. -DV0=2)
2009 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
2010 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must be passed at compile time.
2011 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
2012 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
2013 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
2014 * @note Only the following configurations of M0, N0 and K0 are currently supported:
2015 *  - M0 = 2, 3, 4, 5, 6, 7, 8
2016 *  - N0 = 4, 8, 16
2017 *  - K0 = 4, 8, 16
2018 *  - V0 >= 1
2019 *  - H0 >= 1
2020 *
2021 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B values required by some activation functions should also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
2022 *       The activation function is performed after the bias addition
2023 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
2024 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2025 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2026 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2027 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
2028 *
2029 * @param[in]  lhs_ptr                            Pointer to the LHS reshaped matrix. Supported data type: F32
2030 * @param[in]  lhs_stride_x                       Stride of the LHS reshaped matrix in X dimension (in bytes)
2031 * @param[in]  lhs_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
2032 * @param[in]  lhs_stride_y                       Stride of the LHS reshaped matrix in Y dimension (in bytes)
2033 * @param[in]  lhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
2034 * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS reshaped matrix
2035 * @param[in]  rhs_img                            The RHS reshaped matrix as OpenCL image object. Supported data type: same as @p lhs_ptr
2036 * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
2037 * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
2038 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
2039 * @param[in]  bias_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
2040 * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
2041 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
2042 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data type: same as @p lhs_ptr
2043 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
2044 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
2045 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
2046 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
2047 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
2048 * @param[in]  lhs_stride_z                       Stride of the LHS reshaped matrix in Z dimension (in bytes)
2049 * @param[in]  rhs_stride_z                       Stride of the RHS reshaped matrix in Z dimension (in bytes)
2050 * @param[in]  bias_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
2051 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
2052 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2053 * @param[in]  M                                  Number of rows in LHS matrix not reshaped.
2054 * @param[in]  N                                  Number of columns in RHS matrix not reshaped.
2055 * @param[in]  K                                  Number of columns in LHS matrix and rows in RHS matrix not reshaped.
2056 */
2057__kernel void gemm_mm_reshaped_lhs_nt_rhs_t_texture(IMAGE_DECLARATION(lhs),
2058                                                    __read_only image2d_t rhs_img,
2059#if defined(BETA)
2060                                                    IMAGE_DECLARATION(bias),
2061#endif // defined(BETA)
2062                                                    IMAGE_DECLARATION(dst),
2063                                                    uint lhs_stride_z,
2064                                                    uint rhs_stride_z,
2065#if defined(BETA)
2066                                                    uint bias_stride_z,
2067#endif //defined(BETA)
2068                                                    uint dst_stride_z
2069#if defined(REINTERPRET_OUTPUT_AS_3D)
2070                                                    ,
2071                                                    uint dst_cross_plane_pad
2072#endif // REINTERPRET_OUTPUT_AS_3D
2073                                                    ,
2074                                                    const int M,
2075                                                    const int N,
2076                                                    const int K)
2077{
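    // Illustrative example (hypothetical values, consistent with the constraints documented above):
    // this kernel could, for instance, be compiled with
    //   -DDATA_TYPE=float -DDATA_TYPE_ACCUMULATOR=float -DM0=4 -DN0=4 -DK0=4 -DV0=2 -DH0=2 -DRHS_HEIGHT=128
    //   -DPARTIAL_STORE_M0=1 -DPARTIAL_STORE_N0=1 -DOPENCL_IMAGE_SUPPORT -DGEMM_MM_RESHAPED_LHS_NT_RHS_T_TEXTURE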
2078    // Pixel unit
2079#define PIXEL_UNIT CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(K0)
2080
2081    // Block size
2082#define LHS_BLOCK_SIZE ((K0) * (M0))
2083
2084#if defined(LHS_INTERLEAVE)
2085#define LHS_OFFSET_X (K0)
2086#define LHS_STEP_X ((K0) * (V0))
2087#define LHS_STEP_LOOP (1)
2088#else // defined(LHS_INTERLEAVE)
2089#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
2090#define LHS_STEP_X (K0)
2091#define LHS_STEP_LOOP (V0)
2092#endif // defined(LHS_INTERLEAVE)
2093
2094    // Block size
2095#define RHS_BLOCK_SIZE (PIXEL_UNIT * (N0))
2096
2097    // RHS offset and step X
2098#if defined(RHS_INTERLEAVE)
2099#define RHS_OFFSET_X (PIXEL_UNIT)
2100#define RHS_STEP_X (PIXEL_UNIT * (H0))
2101#define RHS_STEP_LOOP (1)
2102#else // defined(RHS_INTERLEAVE)
2103#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2104#define RHS_STEP_X (PIXEL_UNIT)
2105#define RHS_STEP_LOOP (H0)
2106#endif // defined(RHS_INTERLEAVE)
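    // Illustrative example (assuming one RGBA image pixel holds 4 F32 elements, i.e. PIXEL_UNIT = K0 / 4):
    // with K0 = 4, N0 = 4, H0 = 2 and no RHS_INTERLEAVE, PIXEL_UNIT = 1, RHS_BLOCK_SIZE = 4,
    // RHS_OFFSET_X = 4, RHS_STEP_X = 1 and RHS_STEP_LOOP = 2.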
2107
2108#if defined(DUMMY_WORK_ITEMS)
2109    if((get_global_id(0) * N0 >= N) || (get_global_id(1) * M0 >= M))
2110    {
2111        return;
2112    }
2113#endif // defined(DUMMY_WORK_ITEMS)
2114
2115    // Compute LHS matrix address
2116    __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
2117                               (get_global_id(2) * lhs_stride_z);
2118
2119#if defined(MATRIX_B_DEPTH)
2120    // Do not slide matrix B if matrix B has 3 dimensions and matrix A has more than 3
2121    const uint z_rhs = (get_global_id(2) % MATRIX_B_DEPTH);
2122#else  // defined(MATRIX_B_DEPTH)
2123    const uint z_rhs = get_global_id(2);
2124#endif // defined(MATRIX_B_DEPTH)
2125
2126    // Compute RHS matrix coordinates
2127    uint       x_rhs = (get_global_id(0) % H0) * (uint)RHS_OFFSET_X;
2128    const uint y_rhs = (get_global_id(0) / (uint)H0) + z_rhs * RHS_HEIGHT;
2129
2130    // Initialize the accumulators
2131    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0), c, 0);
2132
2133    REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
2134    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
2135
2136    for(int i = 0; i < K; i += K0)
2137    {
2138        // Load values from LHS matrix
2139        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);
2140
2141        // Load values from RHS matrix stored in a cl_image
2142        REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), b, 0);
2143        LOAD_TEXTURE2D(N0, PIXEL_UNIT, DATA_TYPE, b, rhs_img, x_rhs, y_rhs, RHS_STEP_X, 0);
2144
2145        // Accumulate
2146        ARM_DOT_K0XN0(a0, b, c0);
2147#if M0 > 1
2148        ARM_DOT_K0XN0(a1, b, c1);
2149#endif // M0 > 1
2150#if M0 > 2
2151        ARM_DOT_K0XN0(a2, b, c2);
2152#endif // M0 > 2
2153#if M0 > 3
2154        ARM_DOT_K0XN0(a3, b, c3);
2155#endif // M0 > 3
2156#if M0 > 4
2157        ARM_DOT_K0XN0(a4, b, c4);
2158#endif // M0 > 4
2159#if M0 > 5
2160        ARM_DOT_K0XN0(a5, b, c5);
2161#endif // M0 > 5
2162#if M0 > 6
2163        ARM_DOT_K0XN0(a6, b, c6);
2164#endif // M0 > 6
2165#if M0 > 7
2166        ARM_DOT_K0XN0(a7, b, c7);
2167#endif // M0 > 7
2168
2169        lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
2170
2171        x_rhs += N0 * RHS_STEP_X * RHS_STEP_LOOP;
2172    }
2173
2174    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
2175
2176    REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
2177
2178    const bool cond_y = ((get_global_id(1) + 1) * M0 >= M);
2179    const bool cond_x = ((get_global_id(0) + 1) * N0 >= N);
2180
2181#if defined(REINTERPRET_OUTPUT_AS_3D)
2182
2183    // The plane (zout) is calculated by dividing M (y * M0) by HEIGHT_GEMM3D
2184    CALCULATE_Z_OFFSET(M0, uint, zout, get_global_id(1) * (uint)M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
2185    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2186    // multiply dst_stride_z by DEPTH_GEMM3D
2187    dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
2188
2189#else // defined(REINTERPRET_OUTPUT_AS_3D)
2190
2191    // Add offset for batched GEMM
2192    dst_addr += get_global_id(2) * dst_stride_z;
2193
2194#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2195
2196    // Multiply by the weight of matrix-matrix product and store the result
2197#if defined(ALPHA)
2198    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
2199#endif // defined(ALPHA)
2200
2201    // Add beta*bias
2202#if defined(BETA)
2203#if defined(BROADCAST_BIAS)
2204    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
2205
2206    LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, 1, PARTIAL_STORE_N0, false, cond_x);
2207
2208#ifndef UNIT_BETA
2209    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
2210#endif // UNIT_BETA
2211
2212    // c = c + bias[broadcasted]
2213#if defined(MIXED_PRECISION)
2214    CONVERT_BLOCK(1, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
2215    ADD_BLOCK_BROADCAST(M0, c, bias_hp0);
2216#else  // defined(MIXED_PRECISION)
2217    ADD_BLOCK_BROADCAST(M0, c, bias0);
2218#endif // defined(MIXED_PRECISION)
2219
2220#else // defined(BROADCAST_BIAS)
2221    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
2222                                    2) * bias_stride_z;
2223
2224    LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
2225
2226#ifndef UNIT_BETA
2227    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
2228#endif // UNIT_BETA
2229
2230    // c = c + bias
2231#if defined(MIXED_PRECISION)
2232    CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
2233    ADD_BLOCK(M0, c, bias_hp);
2234#else  // defined(MIXED_PRECISION)
2235    ADD_BLOCK(M0, c, bias);
2236#endif // defined(MIXED_PRECISION)
2237
2238#endif // defined(BROADCAST_BIAS)
2239#endif // defined(BETA)
2240
2241#if defined(ACTIVATION_TYPE)
2242#if defined(MIXED_PRECISION)
2243    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE_ACCUMULATOR, N0, c, A_VAL, B_VAL);
2244#else  // defined(MIXED_PRECISION)
2245    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, N0, c, A_VAL, B_VAL);
2246#endif // defined(MIXED_PRECISION)
2247#endif // defined(ACTIVATION_TYPE)
2248
2249    // Store output block
2250#if defined(MIXED_PRECISION)
2251    CONVERT_BLOCK(M0, N0, DATA_TYPE, c, c_lp);
2252    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c_lp, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
2253#else  // defined(MIXED_PRECISION)
2254    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
2255#endif // defined(MIXED_PRECISION)
2256
2257#undef LHS_BLOCK_SIZE
2258#undef LHS_OFFSET_X
2259#undef LHS_STEP_X
2260#undef RHS_BLOCK_SIZE
2261#undef RHS_OFFSET_X
2262#undef RHS_STEP_X
2263#undef PIXEL_UNIT
2264#undef LHS_STEP_LOOP
2265#undef RHS_STEP_LOOP
2266}
2267#endif // defined(OPENCL_IMAGE_SUPPORT) && defined(GEMM_MM_RESHAPED_LHS_NT_RHS_T_TEXTURE)
2268
2269#if defined(LHS_TRANSPOSE)
2270
2271#define VTYPE(TYPE, SIZE) VEC_DATA_TYPE(TYPE, SIZE)
2272
2273#if defined(MIXED_PRECISION)
2274
2275#if(GPU_ARCH == GPU_ARCH_MIDGARD)
2276#define ARM_VFMA(N0, a, b, c) c += (CONVERT(a, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))) * (CONVERT(b, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0)));
2277#else // GPU_ARCH == GPU_ARCH_MIDGARD
2278#define ARM_VFMA(N0, a, b, c) c = fma((CONVERT(a, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))), (CONVERT(b, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))), (c));
2279#endif // GPU_ARCH == GPU_ARCH_MIDGARD
2280
2281#else // defined(MIXED_PRECISION)
2282
2283#if(GPU_ARCH == GPU_ARCH_MIDGARD)
2284#define ARM_VFMA(N0, a, b, c) c += (a) * (b);
2285#else // GPU_ARCH == GPU_ARCH_MIDGARD
2286#define ARM_VFMA(N0, a, b, c) c = fma((a), (b), (c));
2287#endif // GPU_ARCH == GPU_ARCH_MIDGARD
2288
2289#endif // defined(MIXED_PRECISION)
2290
2291#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C)         \
2292    ({                                                 \
2293        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a), b, (C##0)); \
2294    })
2295#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C)            \
2296    ({                                                    \
2297        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s0), b, (C##0)); \
2298        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s1), b, (C##1)); \
2299    })
2300#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C)            \
2301    ({                                                    \
2302        ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C);           \
2303        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s2), b, (C##2)); \
2304    })
2305#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C)            \
2306    ({                                                    \
2307        ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C);           \
2308        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s3), b, (C##3)); \
2309    })
2310#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C)            \
2311    ({                                                    \
2312        ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C);           \
2313        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s4), b, (C##4)); \
2314        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s5), b, (C##5)); \
2315        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s6), b, (C##6)); \
2316        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s7), b, (C##7)); \
2317    })
2318
2319// Factory macro for the column-vector (transposed) by row-vector (not transposed) multiplication. K0 = 1
2320// a is the column-vector (transposed)
2321// b is the row-vector (not transposed)
2322// C is the output matrix
2323// Lower case is a vector (a, b)
2324// Upper case is a matrix (C)
2325#define ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, a, b, C) ARM_VVM_T_NT_##M0##xN0x1(N0, TYPE, a, b, C)
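// Illustrative expansion (comment only, not part of the build): assuming M0 = 2, N0 = 4,
// TYPE = float and a non-Midgard, non-mixed-precision build, a single call
//     ARM_VVM_T_NT_M0xN0x1(2, 4, float, a0, b0, c);
// unrolls to the rank-1 update
//     c0 = fma((float4)(a0.s0), b0, c0);
//     c1 = fma((float4)(a0.s1), b0, c1);
// i.e. each element of the transposed column-vector a0 is broadcast across the row-vector b0
// and accumulated into the corresponding row of the output block c.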
2326
2327#define ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, A, B, C)             \
2328    ({                                                         \
2329        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##0), (B##0), C); \
2330    })
2331#define ARM_MM_T_NT_M0xN0x2(M0, N0, TYPE, A, B, C)             \
2332    ({                                                         \
2333        ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, A, B, C);            \
2334        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##1), (B##1), C); \
2335    })
2336#define ARM_MM_T_NT_M0xN0x3(M0, N0, TYPE, A, B, C)             \
2337    ({                                                         \
2338        ARM_MM_T_NT_M0xN0x2(M0, N0, TYPE, A, B, C);            \
2339        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##2), (B##2), C); \
2340    })
2341#define ARM_MM_T_NT_M0xN0x4(M0, N0, TYPE, A, B, C)             \
2342    ({                                                         \
2343        ARM_MM_T_NT_M0xN0x3(M0, N0, TYPE, A, B, C);            \
2344        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##3), (B##3), C); \
2345    })
2346#define ARM_MM_T_NT_M0xN0x8(M0, N0, TYPE, A, B, C)             \
2347    ({                                                         \
2348        ARM_MM_T_NT_M0xN0x4(M0, N0, TYPE, A, B, C);            \
2349        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##4), (B##4), C); \
2350        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##5), (B##5), C); \
2351        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##6), (B##6), C); \
2352        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##7), (B##7), C); \
2353    })
2354#define ARM_MM_T_NT_M0xN0x16(M0, N0, TYPE, A, B, C)           \
2355    ({                                                        \
2356        ARM_MM_T_NT_M0xN0x8(M0, N0, TYPE, A, B, C);           \
2357        ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##8), (B##8), C); \
2358        ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##9), (B##9), C); \
2359        ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##A), (B##A), C); \
2360        ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##B), (B##B), C); \
2361        ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##C), (B##C), C); \
2362        ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##D), (B##D), C); \
2363        ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##E), (B##E), C); \
2364        ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, (A##F), (B##F), C); \
2365    })
2366
2367// Factory macro for the matrix (transposed) by matrix (not transposed) multiplication.
2368// The dimensions for this matrix multiplication are defined through M0, N0 and K0
2369// The dimensions supported are:
2370// M0: 1, 2, 3, 4, 8
2371// N0: 1, 2, 3, 4, 8, 16
2372// K0: 1, 2, 3, 4, 8, 16
2373// This macro calls the column-vector by row-vector multiplication macro (ARM_VVM_T_NT_M0xN0x1) K0 times
2374// A, B and C are matrices
2375#define ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C) \
2376    CONCAT(ARM_MM_T_NT_M0xN0x, K0)             \
2377    (M0, N0, TYPE, A, B, C)
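// Illustrative expansion (comment only, not part of the build): assuming M0 = 2, N0 = 4,
// K0 = 2 and TYPE = float,
//     ARM_MM_T_NT(2, 4, 2, float, a, b, c);
// resolves to ARM_MM_T_NT_M0xN0x2(2, 4, float, a, b, c), which on a non-Midgard,
// non-mixed-precision build unrolls to
//     c0 = fma((float4)(a0.s0), b0, c0);    // k = 0, output row 0
//     c1 = fma((float4)(a0.s1), b0, c1);    // k = 0, output row 1
//     c0 = fma((float4)(a1.s0), b1, c0);    // k = 1, output row 0
//     c1 = fma((float4)(a1.s1), b1, c1);    // k = 1, output row 1
// i.e. every K step performs one rank-1 update of the M0 x N0 accumulator block.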
2378
2379#if defined(GEMM_MM_RESHAPED_LHS_T_RHS_NT)
2380/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
2381 *  The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be transposed
2382 *  The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be NOT transposed
2383 * @note This kernel is duplicated in /experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped.cl
2384 *
2385 * @note LHS_TRANSPOSE should be passed at compile time in order to compile this OpenCL kernel (e.g. -DLHS_TRANSPOSE).
2386 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
2387 * @note The GEMM's dimensions M, N and K must be passed at runtime as kernel parameters.
2388 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
2389 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. -DV0=2)
2390 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
2391 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must be passed at compile time.
2392 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
2393 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
2394 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
2395 * @note Only the following configurations of M0, N0 and K0 are currently supported:
2396 *  - M0 = 2, 3, 4, 8
2397 *  - N0 = 2, 3, 4, 8, 16
2398 *  - K0 = 2, 3, 4, 8, 16
2399 *  - V0 >= 1
2400 *  - H0 >= 1
2401 *
2402 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
2403 *       The activation function is performed after the bias addition.
2404 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
2405 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2406 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2407 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2408 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
2409 *
2410 * @param[in]  lhs_ptr                            Pointer to the LHS reshaped matrix. Supported data type: F16/F32
2411 * @param[in]  lhs_stride_x                       Stride of the LHS reshaped matrix in X dimension (in bytes)
2412 * @param[in]  lhs_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
2413 * @param[in]  lhs_stride_y                       Stride of the LHS reshaped matrix in Y dimension (in bytes)
2414 * @param[in]  lhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
2415 * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS reshaped matrix
2416 * @param[in]  rhs_ptr                            Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
2417 * @param[in]  rhs_stride_x                       Stride of the RHS reshaped matrix in X dimension (in bytes)
2418 * @param[in]  rhs_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
2419 * @param[in]  rhs_stride_y                       Stride of the RHS reshaped matrix in Y dimension (in bytes)
2420 * @param[in]  rhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
2421 * @param[in]  rhs_offset_first_element_in_bytes  The offset of the first element in the RHS reshaped matrix
2422 * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
2423 * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
2424 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
2425 * @param[in]  bias_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
2426 * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
2427 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
2428 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data type: same as @p lhs_ptr
2429 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
2430 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
2431 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
2432 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
2433 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
2434 * @param[in]  lhs_stride_z                       Stride of the LHS reshaped matrix in Z dimension (in bytes)
2435 * @param[in]  rhs_stride_z                       Stride of the RHS reshaped matrix in Z dimension (in bytes)
2436 * @param[in]  bias_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
2437 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
2438 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2439 * @param[in]  M                                  Number of rows in LHS matrix not reshaped.
2440 * @param[in]  N                                  Number of columns in RHS matrix not reshaped.
2441 * @param[in]  K                                  Number of columns in LHS matrix and rows in RHS matrix not reshaped.
2442 */
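// Illustrative build-option set (a sketch only; the values below are hypothetical and must
// match the options used when reshaping the LHS and RHS matrices):
//     -DGEMM_MM_RESHAPED_LHS_T_RHS_NT -DLHS_TRANSPOSE
//     -DDATA_TYPE=float -DDATA_TYPE_ACCUMULATOR=float
//     -DM0=4 -DN0=4 -DK0=4 -DV0=2 -DH0=2
//     -DPARTIAL_STORE_M0=1 -DPARTIAL_STORE_N0=1
// plus, optionally, -DLHS_INTERLEAVE / -DRHS_INTERLEAVE, -DALPHA=<alpha>, -DBETA=<beta> and
// -DACTIVATION_TYPE=<act> with -DA_VAL= / -DB_VAL=, as described in the notes above.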
2443__kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs),
2444                                            IMAGE_DECLARATION(rhs),
2445#if defined(BETA)
2446                                            IMAGE_DECLARATION(bias),
2447#endif // defined(BETA)
2448                                            IMAGE_DECLARATION(dst),
2449                                            uint lhs_stride_z,
2450                                            uint rhs_stride_z,
2451#if defined(BETA)
2452                                            uint bias_stride_z,
2453#endif //defined(BETA)
2454                                            uint dst_stride_z
2455#if defined(REINTERPRET_OUTPUT_AS_3D)
2456                                            ,
2457                                            uint dst_cross_plane_pad
2458#endif // REINTERPRET_OUTPUT_AS_3D
2459                                            ,
2460                                            const int M,
2461                                            const int N,
2462                                            const int K)
2463{
2464    // Block size
2465#define LHS_BLOCK_SIZE ((K0) * (M0))
2466
2467#if defined(LHS_INTERLEAVE)
2468#define LHS_OFFSET_X (M0)
2469#define LHS_STEP_X ((M0) * (V0))
2470#define LHS_STEP_LOOP (1)
2471#else // defined(LHS_INTERLEAVE)
2472#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
2473#define LHS_STEP_X (M0)
2474#define LHS_STEP_LOOP (V0)
2475#endif // defined(LHS_INTERLEAVE)
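    // A possible reading of the defines above (illustrative only): with LHS_INTERLEAVE, the V0
    // vertical M0xK0 blocks stored on the same output row of the reshaped LHS are interleaved
    // one M0-wide column at a time, so a block starts LHS_OFFSET_X = M0 elements into its group
    // and consecutive K steps of the same block are LHS_STEP_X = M0 * V0 elements apart.
    // Without LHS_INTERLEAVE the blocks are stored back-to-back: the in-block step is M0 and
    // the K loop below skips the remaining blocks with "lhs += (M0 * K0 * (V0 - 1))".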
2476
2477    // Block size
2478#define RHS_BLOCK_SIZE ((K0) * (N0))
2479
2480    // RHS offset and step X
2481#if defined(RHS_INTERLEAVE)
2482#define RHS_OFFSET_X (N0)
2483#define RHS_STEP_X ((N0) * (H0))
2484#else // defined(RHS_INTERLEAVE)
2485#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2486#define RHS_STEP_X (N0)
2487#endif // defined(RHS_INTERLEAVE)
2488
2489    const uint x = get_global_id(0);
2490    const uint y = get_global_id(1);
2491    const uint z = get_global_id(2);
2492
2493    const bool cond_y = ((get_global_id(1) + 1) * M0 >= M);
2494    const bool cond_x = ((get_global_id(0) + 1) * N0 >= N);
2495
2496#if defined(DUMMY_WORK_ITEMS)
2497    if((x * N0 >= N) || (y * M0 >= M))
2498    {
2499        return;
2500    }
2501#endif // defined(DUMMY_WORK_ITEMS)
2502
2503    // Compute LHS matrix address
2504    __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (y % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (y / V0) * (uint)lhs_stride_y + (z * lhs_stride_z);
2505
2506    // Compute RHS matrix address
2507    __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
2508
2509#if defined(MATRIX_B_DEPTH)
2510    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2511    rhs_addr += (z % MATRIX_B_DEPTH) * rhs_stride_z;
2512#else  // defined(MATRIX_B_DEPTH)
2513    rhs_addr += z * rhs_stride_z;
2514#endif // defined(MATRIX_B_DEPTH)
2515
2516    // Initialize the accumulators
2517    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0), c, 0);
2518
2519    REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
2520
2521    __global DATA_TYPE *lhs = (__global DATA_TYPE *)(lhs_addr);
2522    __global DATA_TYPE *rhs = (__global DATA_TYPE *)(rhs_addr);
2523
2524    for(int i = 0; i < K; i += K0)
2525    {
2526        VEC_DATA_TYPE(DATA_TYPE, M0)
2527        a0;
2528        VEC_DATA_TYPE(DATA_TYPE, N0)
2529        b0;
2530
2531        a0 = VLOAD(M0)(0, lhs);
2532        b0 = VLOAD(N0)(0, rhs);
2533
2534        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2535
2536        lhs += LHS_STEP_X;
2537        rhs += RHS_STEP_X;
2538
2539#if K0 > 1
2540        a0 = VLOAD(M0)(0, lhs);
2541        b0 = VLOAD(N0)(0, rhs);
2542
2543        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2544
2545        lhs += LHS_STEP_X;
2546        rhs += RHS_STEP_X;
2547#endif // K0 > 1
2548
2549#if K0 > 2
2550        a0 = VLOAD(M0)(0, lhs);
2551        b0 = VLOAD(N0)(0, rhs);
2552
2553        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2554
2555        lhs += LHS_STEP_X;
2556        rhs += RHS_STEP_X;
2557#endif // K0 > 2
2558
2559#if K0 > 3
2560        a0 = VLOAD(M0)(0, lhs);
2561        b0 = VLOAD(N0)(0, rhs);
2562
2563        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2564
2565        lhs += LHS_STEP_X;
2566        rhs += RHS_STEP_X;
2567#endif // K0 > 3
2568
2569#if K0 > 4
2570        a0 = VLOAD(M0)(0, lhs);
2571        b0 = VLOAD(N0)(0, rhs);
2572
2573        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2574
2575        lhs += LHS_STEP_X;
2576        rhs += RHS_STEP_X;
2577
2578        a0 = VLOAD(M0)(0, lhs);
2579        b0 = VLOAD(N0)(0, rhs);
2580
2581        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2582
2583        lhs += LHS_STEP_X;
2584        rhs += RHS_STEP_X;
2585
2586        a0 = VLOAD(M0)(0, lhs);
2587        b0 = VLOAD(N0)(0, rhs);
2588
2589        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2590
2591        lhs += LHS_STEP_X;
2592        rhs += RHS_STEP_X;
2593
2594        a0 = VLOAD(M0)(0, lhs);
2595        b0 = VLOAD(N0)(0, rhs);
2596
2597        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2598
2599        lhs += LHS_STEP_X;
2600        rhs += RHS_STEP_X;
2601#endif // K0 > 4
2602
2603#if K0 > 8
2604        a0 = VLOAD(M0)(0, lhs);
2605        b0 = VLOAD(N0)(0, rhs);
2606
2607        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2608
2609        lhs += LHS_STEP_X;
2610        rhs += RHS_STEP_X;
2611
2612        a0 = VLOAD(M0)(0, lhs);
2613        b0 = VLOAD(N0)(0, rhs);
2614
2615        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2616
2617        lhs += LHS_STEP_X;
2618        rhs += RHS_STEP_X;
2619
2620        a0 = VLOAD(M0)(0, lhs);
2621        b0 = VLOAD(N0)(0, rhs);
2622
2623        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2624
2625        lhs += LHS_STEP_X;
2626        rhs += RHS_STEP_X;
2627
2628        a0 = VLOAD(M0)(0, lhs);
2629        b0 = VLOAD(N0)(0, rhs);
2630
2631        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2632
2633        lhs += LHS_STEP_X;
2634        rhs += RHS_STEP_X;
2635
2636        a0 = VLOAD(M0)(0, lhs);
2637        b0 = VLOAD(N0)(0, rhs);
2638
2639        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2640
2641        lhs += LHS_STEP_X;
2642        rhs += RHS_STEP_X;
2643
2644        a0 = VLOAD(M0)(0, lhs);
2645        b0 = VLOAD(N0)(0, rhs);
2646
2647        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2648
2649        lhs += LHS_STEP_X;
2650        rhs += RHS_STEP_X;
2651
2652        a0 = VLOAD(M0)(0, lhs);
2653        b0 = VLOAD(N0)(0, rhs);
2654
2655        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2656
2657        lhs += LHS_STEP_X;
2658        rhs += RHS_STEP_X;
2659
2660        a0 = VLOAD(M0)(0, lhs);
2661        b0 = VLOAD(N0)(0, rhs);
2662
2663        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2664
2665        lhs += LHS_STEP_X;
2666        rhs += RHS_STEP_X;
2667#endif // K0 > 8
2668
2669#ifndef LHS_INTERLEAVE
2670        lhs += (M0 * K0 * (V0 - 1));
2671#endif // LHS_INTERLEAVE
2672
2673#ifndef RHS_INTERLEAVE
2674        rhs += (N0 * K0 * (H0 - 1));
2675#endif // RHS_INTERLEAVE
2676    }
2677
2678    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
2679
2680    REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
2681
2682#if defined(REINTERPRET_OUTPUT_AS_3D)
2683
2684    // The plane (zout) is calculated by dividing M (y * M0) by HEIGHT_GEMM3D
2685    CALCULATE_Z_OFFSET(M0, uint, zout, y * (uint)M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
2686    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2687    // multiply dst_stride_z by DEPTH_GEMM3D
2688    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2689
2690#else // defined(REINTERPRET_OUTPUT_AS_3D)
2691
2692    // Add offset for batched GEMM
2693    dst_addr += z * dst_stride_z;
2694
2695#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2696
2697    // Multiply by the weight of matrix-matrix product and store the result
2698#if defined(ALPHA)
2699    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
2700#endif // defined(ALPHA)
2701
2702    // Add beta*bias
2703#if defined(BETA)
2704#if defined(BROADCAST_BIAS)
2705    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE));
2706
2707    LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, 1, PARTIAL_STORE_N0, false, cond_x);
2708
2709#ifndef UNIT_BETA
2710    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
2711#endif // UNIT_BETA
2712
2713    // c = c + bias[broadcasted]
2714#if defined(MIXED_PRECISION)
2715    CONVERT_BLOCK(1, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
2716    ADD_BLOCK_BROADCAST(M0, c, bias_hp0);
2717#else  // defined(MIXED_PRECISION)
2718    ADD_BLOCK_BROADCAST(M0, c, bias0);
2719#endif // defined(MIXED_PRECISION)
2720
2721#else // defined(BROADCAST_BIAS)
2722    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * bias_stride_y) +
2723                                    z * bias_stride_z;
2724
2725    LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
2726
2727#ifndef UNIT_BETA
2728    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
2729#endif // UNIT_BETA
2730
2731#if defined(MIXED_PRECISION)
2732    CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
2733    ADD_BLOCK(M0, c, bias_hp);
2734#else  // defined(MIXED_PRECISION)
2735    ADD_BLOCK(M0, c, bias);
2736#endif // defined(MIXED_PRECISION)
2737
2738#endif // defined(BROADCAST_BIAS)
2739#endif // defined(BETA)
2740
2741#if defined(ACTIVATION_TYPE)
2742#if defined(MIXED_PRECISION)
2743    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE_ACCUMULATOR, N0, c, A_VAL, B_VAL);
2744#else  // defined(MIXED_PRECISION)
2745    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, N0, c, A_VAL, B_VAL);
2746#endif // defined(MIXED_PRECISION)
2747#endif // defined(ACTIVATION_TYPE)
2748
2749    // Store output block
2750#if defined(MIXED_PRECISION)
2751    CONVERT_BLOCK(M0, N0, DATA_TYPE, c, c_lp);
2752    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c_lp, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
2753#else  // defined(MIXED_PRECISION)
2754    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
2755#endif // defined(MIXED_PRECISION)
2756
2757#undef LHS_BLOCK_SIZE
2758#undef LHS_OFFSET_X
2759#undef LHS_STEP_X
2760#undef RHS_BLOCK_SIZE
2761#undef RHS_OFFSET_X
2762#undef RHS_STEP_X
2763}
2764#endif // defined(GEMM_MM_RESHAPED_LHS_T_RHS_NT)
2765
2766#if defined(OPENCL_IMAGE_SUPPORT) && defined(GEMM_MM_RESHAPED_LHS_T_RHS_NT_TEXTURE)
2767/** This OpenCL kernel computes the matrix multiplication between 2 matrices. The RHS matrix is stored in OpenCL image object.
2768 *  The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be transposed
2769 *  The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be NOT transposed
2770 * @note This kernel is duplicated in /experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped.cl
2771 *
2772 * @note -DOPENCL_IMAGE_SUPPORT must be passed at compile time in order to compile this OpenCL kernel
2773 * @note LHS_TRANSPOSE should be passed at compile time in order to compile this OpenCL kernel (e.g. -DLHS_TRANSPOSE).
2774 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
2775 * @note The GEMM's dimensions M, N and K must be passed at runtime as kernel parameters.
2776 * @note The height of the RHS matrix, defined before creating the OpenCL image object from the OpenCL buffer, should be passed at compile time using -DRHS_HEIGHT=<value> (e.g. -DRHS_HEIGHT=32)
2777 *       Since we cannot create a 3D image from a buffer, the third dimension could be collapsed with the second dimension, so RHS_HEIGHT
2778 *       could be different from the value returned by get_image_height(rhs_img).
2779 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
2780 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. -DV0=2)
2781 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
2782 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must be passed at compile time.
2783 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
2784 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
2785 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
2786 * @note Only the following configurations of M0, N0 and K0 are currently supported:
2787 *  - M0 = 2, 3, 4, 8
2788 *  - N0 = 4, 8, 16
2789 *  - K0 = 4, 8, 16
2790 *  - V0 >= 1
2791 *  - H0 >= 1
2792 *
2793 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
2794 *       The activation function is performed after the bias addition.
2795 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
2796 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2797 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2798 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2799 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
2800 *
2801 * @param[in]  lhs_ptr                            Pointer to the LHS reshaped matrix. Supported data type: F32
2802 * @param[in]  lhs_stride_x                       Stride of the LHS reshaped matrix in X dimension (in bytes)
2803 * @param[in]  lhs_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
2804 * @param[in]  lhs_stride_y                       Stride of the LHS reshaped matrix in Y dimension (in bytes)
2805 * @param[in]  lhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
2806 * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS reshaped matrix
2807 * @param[in]  rhs_img                            The RHS reshaped matrix as cl_image 2d. Supported data type: same as @p lhs_ptr
2808 * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
2809 * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
2810 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
2811 * @param[in]  bias_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
2812 * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
2813 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
2814 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data type: same as @p lhs_ptr
2815 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
2816 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
2817 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
2818 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
2819 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
2820 * @param[in]  lhs_stride_z                       Stride of the LHS reshaped matrix in Z dimension (in bytes)
2821 * @param[in]  rhs_stride_z                       Stride of the RHS reshaped matrix in Z dimension (in bytes)
2822 * @param[in]  bias_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
2823 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
2824 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2825 * @param[in]  M                                  Number of rows in LHS matrix not reshaped.
2826 * @param[in]  N                                  Number of columns in RHS matrix not reshaped.
2827 * @param[in]  K                                  Number of columns in LHS matrix and rows in RHS matrix not reshaped.
2828 */
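// Illustrative build-option set (a sketch only; the values below are hypothetical and must
// match the RHS reshape and the OpenCL image created from it):
//     -DGEMM_MM_RESHAPED_LHS_T_RHS_NT_TEXTURE -DOPENCL_IMAGE_SUPPORT -DLHS_TRANSPOSE
//     -DDATA_TYPE=float -DDATA_TYPE_ACCUMULATOR=float
//     -DM0=4 -DN0=4 -DK0=4 -DV0=2 -DH0=2 -DRHS_HEIGHT=<height of the reshaped RHS image>
//     -DPARTIAL_STORE_M0=1 -DPARTIAL_STORE_N0=1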
2829__kernel void gemm_mm_reshaped_lhs_t_rhs_nt_texture(IMAGE_DECLARATION(lhs),
2830                                                    __read_only image2d_t rhs_img,
2831#if defined(BETA)
2832                                                    IMAGE_DECLARATION(bias),
2833#endif // defined(BETA)
2834                                                    IMAGE_DECLARATION(dst),
2835                                                    uint lhs_stride_z,
2836                                                    uint rhs_stride_z,
2837#if defined(BETA)
2838                                                    uint bias_stride_z,
2839#endif //defined(BETA)
2840                                                    uint dst_stride_z
2841#if defined(REINTERPRET_OUTPUT_AS_3D)
2842                                                    ,
2843                                                    uint dst_cross_plane_pad
2844#endif // REINTERPRET_OUTPUT_AS_3D
2845                                                    ,
2846                                                    const int M,
2847                                                    const int N,
2848                                                    const int K)
2849{
2850    // Pixel unit
2851#define PIXEL_UNIT CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(N0)
2852
2853    // Block size
2854#define LHS_BLOCK_SIZE ((K0) * (M0))
2855
2856#if defined(LHS_INTERLEAVE)
2857#define LHS_OFFSET_X (M0)
2858#define LHS_STEP_X ((M0) * (V0))
2859#define LHS_STEP_LOOP (1)
2860#else // defined(LHS_INTERLEAVE)
2861#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
2862#define LHS_STEP_X (M0)
2863#define LHS_STEP_LOOP (V0)
2864#endif // defined(LHS_INTERLEAVE)
2865
2866    // Block size
2867#define RHS_BLOCK_SIZE ((K0) * (PIXEL_UNIT))
2868
2869    // RHS offset and step X
2870#if defined(RHS_INTERLEAVE)
2871#define RHS_OFFSET_X (PIXEL_UNIT)
2872#define RHS_STEP_X ((PIXEL_UNIT) * (H0))
2873#else // defined(RHS_INTERLEAVE)
2874#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2875#define RHS_STEP_X (PIXEL_UNIT)
2876#endif // defined(RHS_INTERLEAVE)
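    // A possible reading of the defines above (illustrative only): the RHS is read through the
    // cl_image, so X coordinates are expressed in pixels and PIXEL_UNIT is the number of pixels
    // spanned by an N0-wide row. With RHS_INTERLEAVE, the K0xN0 blocks of the H0 column-tiles
    // that share one row of the reshaped RHS are interleaved row by row, so a block starts
    // RHS_OFFSET_X = PIXEL_UNIT pixels into its group and consecutive rows of the same block are
    // RHS_STEP_X = PIXEL_UNIT * H0 pixels apart. Without RHS_INTERLEAVE the blocks are stored
    // back-to-back and the in-block step is simply PIXEL_UNIT.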
2877
2878    const uint x = get_global_id(0);
2879    const uint y = get_global_id(1);
2880    const uint z = get_global_id(2);
2881
2882#if defined(DUMMY_WORK_ITEMS)
2883    if((x * N0 >= N) || (y * M0 >= M))
2884    {
2885        return;
2886    }
2887#endif // defined(DUMMY_WORK_ITEMS)
2888
2889    // Compute LHS matrix address
2890    __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (y % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (y / V0) * (uint)lhs_stride_y + (z * lhs_stride_z);
2891
2892#if defined(MATRIX_B_DEPTH)
2893    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2894    const uint z_rhs = (z % MATRIX_B_DEPTH);
2895#else  // defined(MATRIX_B_DEPTH)
2896    const uint z_rhs = z;
2897#endif // defined(MATRIX_B_DEPTH)
2898
2899    // Compute RHS matrix coordinates
2900    uint       x_rhs = (x % H0) * (uint)RHS_OFFSET_X;
2901    const uint y_rhs = (x / (uint)H0) + z_rhs * RHS_HEIGHT;
2902
2903    // Initialize the accumulators
2904    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0), c, 0);
2905
2906    REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
2907
2908    __global DATA_TYPE *lhs = (__global DATA_TYPE *)(lhs_addr);
2909
2910    for(int i = 0; i < K; i += K0)
2911    {
2912        VEC_DATA_TYPE(DATA_TYPE, M0)
2913        a0;
2914        VEC_DATA_TYPE(DATA_TYPE, N0)
2915        b0;
2916
2917        a0 = VLOAD(M0)(0, lhs);
2918        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 0 * RHS_STEP_X), (y_rhs));
2919
2920        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2921
2922        lhs += LHS_STEP_X;
2923
2924#if K0 > 1
2925        a0 = VLOAD(M0)(0, lhs);
2926        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 1 * RHS_STEP_X), (y_rhs));
2927
2928        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2929
2930        lhs += LHS_STEP_X;
2931#endif // K0 > 1
2932
2933#if K0 > 2
2934        a0 = VLOAD(M0)(0, lhs);
2935        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 2 * RHS_STEP_X), (y_rhs));
2936
2937        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2938
2939        lhs += LHS_STEP_X;
2940#endif // K0 > 2
2941
2942#if K0 > 3
2943        a0 = VLOAD(M0)(0, lhs);
2944        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 3 * RHS_STEP_X), (y_rhs));
2945
2946        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2947
2948        lhs += LHS_STEP_X;
2949#endif // K0 > 3
2950
2951#if K0 > 4
2952        a0 = VLOAD(M0)(0, lhs);
2953        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 4 * RHS_STEP_X), (y_rhs));
2954
2955        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2956
2957        lhs += LHS_STEP_X;
2958
2959        a0 = VLOAD(M0)(0, lhs);
2960        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 5 * RHS_STEP_X), (y_rhs));
2961
2962        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2963
2964        lhs += LHS_STEP_X;
2965
2966        a0 = VLOAD(M0)(0, lhs);
2967        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 6 * RHS_STEP_X), (y_rhs));
2968
2969        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2970
2971        lhs += LHS_STEP_X;
2972
2973        a0 = VLOAD(M0)(0, lhs);
2974        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 7 * RHS_STEP_X), (y_rhs));
2975
2976        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2977
2978        lhs += LHS_STEP_X;
2979#endif // K0 > 4
2980
2981#if K0 > 8
2982        a0 = VLOAD(M0)(0, lhs);
2983        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 8 * RHS_STEP_X), (y_rhs));
2984
2985        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2986
2987        lhs += LHS_STEP_X;
2988
2989        a0 = VLOAD(M0)(0, lhs);
2990        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 9 * RHS_STEP_X), (y_rhs));
2991
2992        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2993
2994        lhs += LHS_STEP_X;
2995
2996        a0 = VLOAD(M0)(0, lhs);
2997        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 10 * RHS_STEP_X), (y_rhs));
2998
2999        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3000
3001        lhs += LHS_STEP_X;
3002
3003        a0 = VLOAD(M0)(0, lhs);
3004        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 11 * RHS_STEP_X), (y_rhs));
3005
3006        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3007
3008        lhs += LHS_STEP_X;
3009
3010        a0 = VLOAD(M0)(0, lhs);
3011        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 12 * RHS_STEP_X), (y_rhs));
3012
3013        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3014
3015        lhs += LHS_STEP_X;
3016
3017        a0 = VLOAD(M0)(0, lhs);
3018        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 13 * RHS_STEP_X), (y_rhs));
3019
3020        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3021
3022        lhs += LHS_STEP_X;
3023
3024        a0 = VLOAD(M0)(0, lhs);
3025        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 14 * RHS_STEP_X), (y_rhs));
3026
3027        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3028
3029        lhs += LHS_STEP_X;
3030
3031        a0 = VLOAD(M0)(0, lhs);
3032        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 15 * RHS_STEP_X), (y_rhs));
3033
3034        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3035
3036        lhs += LHS_STEP_X;
3037#endif // K0 > 8
3038
3039#ifndef LHS_INTERLEAVE
3040        lhs += (M0 * K0 * (V0 - 1));
3041#endif // LHS_INTERLEAVE
3042
3043        x_rhs += K0 * RHS_STEP_X;
3044#ifndef RHS_INTERLEAVE
3045        x_rhs += (PIXEL_UNIT * K0 * (H0 - 1));
3046#endif // RHS_INTERLEAVE
3047    }
3048
3049    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
3050
3051    REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
3052
3053    const bool cond_y = ((get_global_id(1) + 1) * M0 >= M);
3054    const bool cond_x = ((get_global_id(0) + 1) * N0 >= N);
3055
3056#if defined(REINTERPRET_OUTPUT_AS_3D)
3057
3058    // The plane (zout) is calculated by dividing M (y * M0) by HEIGHT_GEMM3D
3059    CALCULATE_Z_OFFSET(M0, uint, zout, y * (uint)M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
3060    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3061    // multiply dst_stride_z by DEPTH_GEMM3D
3062    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3063
3064#else // defined(REINTERPRET_OUTPUT_AS_3D)
3065
3066    // Add offset for batched GEMM
3067    dst_addr += z * dst_stride_z;
3068
3069#endif // defined(REINTERPRET_OUTPUT_AS_3D)
3070
3071    // Multiply by the weight of matrix-matrix product and store the result
3072#if defined(ALPHA)
3073    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
3074#endif // defined(ALPHA)
3075
3076    // Add beta*bias
3077#if defined(BETA)
3078#if defined(BROADCAST_BIAS)
3079    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE));
3080
3081    LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, 1, PARTIAL_STORE_N0, false, cond_x);
3082
3083#ifndef UNIT_BETA
3084    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
3085#endif // UNIT_BETA
3086
3087    // c = c + bias[broadcasted]
3088#if defined(MIXED_PRECISION)
3089    CONVERT_BLOCK(1, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
3090    ADD_BLOCK_BROADCAST(M0, c, bias_hp0);
3091#else  // defined(MIXED_PRECISION)
3092    ADD_BLOCK_BROADCAST(M0, c, bias0);
3093#endif // defined(MIXED_PRECISION)
3094
3095#else // defined(BROADCAST_BIAS)
3096    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * bias_stride_y) + z * bias_stride_z;
3097
3098    LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
3099
3100#ifndef UNIT_BETA
3101    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
3102#endif // UNIT_BETA
3103
3104#if defined(MIXED_PRECISION)
3105    CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
3106    ADD_BLOCK(M0, c, bias_hp);
3107#else  // defined(MIXED_PRECISION)
3108    ADD_BLOCK(M0, c, bias);
3109#endif // defined(MIXED_PRECISION)
3110
3111#endif // defined(BROADCAST_BIAS)
3112#endif // defined(BETA)
3113
3114#if defined(ACTIVATION_TYPE)
3115#if defined(MIXED_PRECISION)
3116    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE_ACCUMULATOR, N0, c, A_VAL, B_VAL);
3117#else  // defined(MIXED_PRECISION)
3118    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, N0, c, A_VAL, B_VAL);
3119#endif // defined(MIXED_PRECISION)
3120#endif // defined(ACTIVATION_TYPE)
3121
3122    // Store output block
3123#if defined(MIXED_PRECISION)
3124    CONVERT_BLOCK(M0, N0, DATA_TYPE, c, c_lp);
3125    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c_lp, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
3126#else  // defined(MIXED_PRECISION)
3127    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
3128#endif // defined(MIXED_PRECISION)
3129
3130#undef LHS_BLOCK_SIZE
3131#undef LHS_OFFSET_X
3132#undef LHS_STEP_X
3133#undef RHS_BLOCK_SIZE
3134#undef RHS_OFFSET_X
3135#undef RHS_STEP_X
3136#undef PIXEL_UNIT
3137#undef LHS_STEP_LOOP
3138#undef RHS_STEP_LOOP
3139}
3140#endif // defined(OPENCL_IMAGE_SUPPORT) && defined(GEMM_MM_RESHAPED_LHS_T_RHS_NT_TEXTURE)
3141
3142#endif // defined(LHS_TRANSPOSE)
3143
3144#endif // defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(DATA_TYPE_ACCUMULATOR)
3145
3146#if defined(M0) && defined(N0) && defined(K0) && defined(DATA_TYPE)
3147
3148#define VFMA(a, b, c)     \
3149    ({                    \
3150        c = fma(a, b, c); \
3151    })
3152
3153#if M0 == 1
3154#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
3155    ({                                                                \
3156        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3157    })
3158#elif M0 == 2 // M0 == 2
3159#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
3160    ({                                                                \
3161        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3162        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3163    })
3164#elif M0 == 3 // M0 == 3
3165#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
3166    ({                                                                \
3167        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3168        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3169        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3170    })
3171#elif M0 == 4 // M0 == 4
3172#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
3173    ({                                                                \
3174        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3175        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3176        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3177        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
3178    })
3179#elif M0 == 5 // M0 == 5
3180#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
3181    ({                                                                \
3182        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3183        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3184        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3185        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
3186        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
3187    })
3188#elif M0 == 6 // M0 == 6
3189#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
3190    ({                                                                \
3191        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3192        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3193        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3194        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
3195        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
3196        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
3197    })
3198#elif M0 == 7 // M0 == 7
3199#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
3200    ({                                                                \
3201        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3202        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3203        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3204        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
3205        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
3206        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
3207        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
3208    })
3209#elif M0 == 8 // M0 == 8
3210#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
3211    ({                                                                \
3212        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3213        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3214        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3215        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
3216        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
3217        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
3218        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
3219        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
3220    })
3221#else // M0 not supported
3222#error "M0 not supported"
3223#endif // M0 not supported
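// Illustrative expansion (comment only, not part of the build): assuming M0 = 2, N0 = 4 and
// DATA_TYPE = float, a single call
//     RHS_VFMA_M0xN0(0, a, b0, c);
// unrolls to
//     c0 = fma((float4)(a0.s0), b0, c0);
//     c1 = fma((float4)(a1.s0), b0, c1);
// i.e. element 0 of every LHS row is broadcast across the N0-wide RHS row b0 and accumulated
// into the matching accumulator row.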
3224
3225#if defined(GEMM_MM_NATIVE)
3226/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
3227 *  The LHS matrix is NOT reshaped
3228 *  The RHS matrix is NOT reshaped
3229 * @note This kernel is duplicated in /experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_native.cl
3230 *
3231 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
3232 * @note The GEMM's dimensions (M, N and K) must be passed at runtime as kernel parameters.
3233 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
3234 * @note The number of K0 partial accumulations must be passed at compile time using -DK0 (e.g., -DK0=2)
3235 * @note The number of N0 columns to process must be passed at compile time using -DN0 (e.g. -DN0=2)
3236 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
3237 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
3238 * @note Only the following configurations of M0, N0 and K0 are currently supported:
3239 *  - M0 = 1, 2, 3, 4, 5, 6, 7, 8
3240 *  - N0 = 2, 3, 4, 8, 16
3241 *  - K0 = 2, 3, 4, 8, 16
3242 *
3243 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
3244 *       The activation function is performed after the bias addition.
3245 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
3246 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
3247 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3248 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3249 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3250 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
3251 *
3252 * @param[in]  lhs_ptr                            Pointer to the LHS matrix. Supported data type: F16/F32
3253 * @param[in]  lhs_stride_x                       Stride of the LHS matrix in X dimension (in bytes)
3254 * @param[in]  lhs_step_x                         lhs_stride_x * number of elements along X processed per workitem(in bytes)
3255 * @param[in]  lhs_stride_y                       Stride of the LHS matrix in Y dimension (in bytes)
3256 * @param[in]  lhs_step_y                         lhs_stride_y * number of elements along Y processed per workitem(in bytes)
3257 * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS matrix
3258 * @param[in]  rhs_ptr                            Pointer to the RHS matrix. Supported data type: same as @p lhs_ptr
3259 * @param[in]  rhs_stride_x                       Stride of the RHS matrix in X dimension (in bytes)
3260 * @param[in]  rhs_step_x                         rhs_stride_x * number of elements along X processed per workitem(in bytes)
3261 * @param[in]  rhs_stride_y                       Stride of the RHS matrix in Y dimension (in bytes)
3262 * @param[in]  rhs_step_y                         rhs_stride_y * number of elements along Y processed per workitem(in bytes)
3263 * @param[in]  rhs_offset_first_element_in_bytes  The offset of the first element in the RHS matrix
3264 * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
3265 * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
3266 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
3267 * @param[in]  bias_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
3268 * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
3269 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
3270 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data type: same as @p lhs_ptr
3271 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
3272 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
3273 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
3274 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
3275 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
3276 * @param[in]  lhs_stride_z                       Stride of the LHS matrix in Z dimension (in bytes)
3277 * @param[in]  rhs_stride_z                       Stride of the RHS matrix in Z dimension (in bytes)
3278 * @param[in]  bias_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
3279 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
3280 * @param[in]  M                                  Number of rows in LHS matrix not reshaped.
3281 * @param[in]  N                                  Number of columns in RHS matrix not reshaped.
3282 * @param[in]  K                                  Number of columns in LHS matrix and rows in RHS matrix not reshaped.
3283 * @param[in]  lhs_cross_plane_pad                (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
3284 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
3285 */
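// Illustrative build-option set (a sketch only; the values below are hypothetical):
//     -DGEMM_MM_NATIVE -DDATA_TYPE=float -DM0=4 -DN0=4 -DK0=4
//     -DPARTIAL_STORE_M0=1 -DPARTIAL_STORE_N0=1
// plus, optionally, -DBETA=<beta> (when a bias is provided), -DACTIVATION_TYPE=<act> with
// -DA_VAL= / -DB_VAL=, and -DDUMMY_WORK_ITEMS, as described in the notes above.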
3286__kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
3287                             IMAGE_DECLARATION(rhs),
3288#if defined(BETA)
3289                             IMAGE_DECLARATION(bias),
3290#endif // defined(BETA)
3291                             IMAGE_DECLARATION(dst),
3292                             uint lhs_stride_z,
3293                             uint rhs_stride_z,
3294#if defined(BETA)
3295                             uint bias_stride_z,
3296#endif //defined(BETA)
3297                             uint      dst_stride_z,
3298                             const int M,
3299                             const int N,
3300                             const int K
3301#if defined(REINTERPRET_INPUT_AS_3D)
3302                             ,
3303                             uint lhs_cross_plane_pad
3304#endif // REINTERPRET_INPUT_AS_3D
3305#if defined(REINTERPRET_OUTPUT_AS_3D)
3306                             ,
3307                             uint dst_cross_plane_pad
3308#endif // REINTERPRET_OUTPUT_AS_3D
3309                            )
3310{
3311    // Block size
3312#define RHS_BLOCK_SIZE ((K0) * (N0))
3313
3314    // RHS offset and step X
3315#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
3316
3317    uint x = get_global_id(0);
3318    uint y = get_global_id(1);
3319    uint z = get_global_id(2);
3320
3321#if defined(DUMMY_WORK_ITEMS)
3322    if((x * N0 >= N) || (y * M0 >= M))
3323    {
3324        return;
3325    }
3326#endif // defined(DUMMY_WORK_ITEMS)
3327
3328    // Compute LHS matrix address
3329    uint lhs_offset = lhs_offset_first_element_in_bytes + COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * (uint)lhs_stride_y;
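    // Illustrative example (numbers chosen only for exposition, assuming the usual clamping
    // behaviour of COMPUTE_M0_START_ROW): with M = 10, M0 = 4 and PARTIAL_STORE_M0 = 2, the
    // work-item at y = 0 starts at row 0 (and later stores only 2 rows), while y = 1 and y = 2
    // start at rows 2 and 6, so no block runs past the M rows of the LHS.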
3330
3331    // Compute RHS matrix address
3332    uint rhs_offset = rhs_offset_first_element_in_bytes + x * N0 * sizeof(DATA_TYPE);
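    // For example (illustrative values): with N0 = 4 and DATA_TYPE = float, work-item x starts
    // x * 4 * sizeof(float) = 16 * x bytes into the first row of the RHS matrix.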
3333
3334#if defined(MATRIX_B_DEPTH)
3335    // Do not slide matrix B if matrix B has only 3 dimensions and matrix A has more than 3
3336    rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
3337#else  // defined(MATRIX_B_DEPTH)
3338    rhs_offset += z * rhs_stride_z;
3339#endif // defined(MATRIX_B_DEPTH)
3340
3341    REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0);
3342    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
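    // zlhs0..zlhs(M0-1) are the per-row Z offsets used for the LHS loads (updated below when the
    // input is reinterpreted as 3D), while zero0..zero15 stay at 0 and are used as the Z offsets
    // for the RHS loads.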
3343
3344#if defined(REINTERPRET_INPUT_AS_3D)
3345    // The plane (zlhs) is calculated by dividing M (y * M0) by HEIGHT_GEMM3D
3346    CALCULATE_Z_OFFSET(M0, uint, zlhs, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
3347
3348    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3349    // multiply lhs_stride_z by DEPTH_GEMM3D
3350    lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
3351
3352#else // defined(REINTERPRET_INPUT_AS_3D)
3353
3354    // Add offset for batched GEMM
3355    lhs_offset += z * lhs_stride_z;
3356
3357#endif // defined(REINTERPRET_INPUT_AS_3D)
3358
3359    // Initialize the accumulators
3360    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0)    c0=0,c1=0,c2=0,... c(M0-1)=0;
3361
3362    int i = 0;
3363#if K0 > 1
3364    for(; i <= (K - K0); i += K0)
3365    {
3366        // Supported cases (M0, K0):
3367        // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
3368        // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
3369        // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
3370        // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
3371        // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
3372        // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
3373        // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
3374        // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
3375        // Load values from LHS matrix
3376        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
3377
3378        // Load values from RHS matrix
3379        LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_ptr, rhs_offset, rhs_stride_y, zero);
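        // a0..a(M0-1) each hold K0 values of one LHS row and b0..b(K0-1) each hold N0 values of
        // one RHS row. Each RHS_VFMA_M0xN0(k, a, bk, c) call below broadcasts element k of every
        // LHS row and accumulates it against RHS row bk into the c accumulators.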
3380
3381        RHS_VFMA_M0xN0(0, a, b0, c);
3382        RHS_VFMA_M0xN0(1, a, b1, c);
3383#if K0 > 2
3384        RHS_VFMA_M0xN0(2, a, b2, c);
3385#endif // K0 > 2
3386#if K0 > 3
3387        RHS_VFMA_M0xN0(3, a, b3, c);
3388#endif // K0 > 3
3389#if K0 > 4
3390        RHS_VFMA_M0xN0(4, a, b4, c);
3391        RHS_VFMA_M0xN0(5, a, b5, c);
3392        RHS_VFMA_M0xN0(6, a, b6, c);
3393        RHS_VFMA_M0xN0(7, a, b7, c);
3394#endif // K0 > 4
3395#if K0 > 8
3396        RHS_VFMA_M0xN0(8, a, b8, c);
3397        RHS_VFMA_M0xN0(9, a, b9, c);
3398        RHS_VFMA_M0xN0(A, a, bA, c);
3399        RHS_VFMA_M0xN0(B, a, bB, c);
3400        RHS_VFMA_M0xN0(C, a, bC, c);
3401        RHS_VFMA_M0xN0(D, a, bD, c);
3402        RHS_VFMA_M0xN0(E, a, bE, c);
3403        RHS_VFMA_M0xN0(F, a, bF, c);
3404#endif // K0 > 8
3405
3406        lhs_offset += K0 * sizeof(DATA_TYPE);
3407        rhs_offset += K0 * rhs_stride_y;
3408    }
3409#endif // K0 > 1
3410    // Left-over accumulations
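    // Processes the remaining iterations (K % K0 of them, or all K iterations when K0 == 1),
    // consuming one LHS element per row and one RHS row per step.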
3411    for(; i < K; ++i)
3412    {
3413        // Load values from LHS matrix
3414        VEC_DATA_TYPE(DATA_TYPE, 2)
3415        a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zlhs0));
3416#if M0 > 1
3417        VEC_DATA_TYPE(DATA_TYPE, 2)
3418        a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zlhs1));
3419#endif // M0 > 1
3420#if M0 > 2
3421        VEC_DATA_TYPE(DATA_TYPE, 2)
3422        a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zlhs2));
3423#endif // M0 > 2
3424#if M0 > 3
3425        VEC_DATA_TYPE(DATA_TYPE, 2)
3426        a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zlhs3));
3427#endif // M0 > 3
3428#if M0 > 4
3429        VEC_DATA_TYPE(DATA_TYPE, 2)
3430        a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zlhs4));
3431#endif // M0 > 4
3432#if M0 > 5
3433        VEC_DATA_TYPE(DATA_TYPE, 2)
3434        a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zlhs5));
3435#endif // M0 > 5
3436#if M0 > 6
3437        VEC_DATA_TYPE(DATA_TYPE, 2)
3438        a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zlhs6));
3439#endif // M0 > 6
3440#if M0 > 7
3441        VEC_DATA_TYPE(DATA_TYPE, 2)
3442        a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zlhs7));
3443#endif // M0 > 7
3444
3445        VEC_DATA_TYPE(DATA_TYPE, N0)
3446        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * rhs_stride_y));
3447        RHS_VFMA_M0xN0(0, a, b, c);
3448
3449        lhs_offset += sizeof(DATA_TYPE);
3450        rhs_offset += rhs_stride_y;
3451    }
3452
3453    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * dst_stride_y);
3454
3455    REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
3456
3457#if defined(REINTERPRET_OUTPUT_AS_3D)
3458    // The plane (zout) is calculated by dividing M (y * M0) by HEIGHT_GEMM3D
3459    CALCULATE_Z_OFFSET(M0, uint, zout, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
3460
3461    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3462    // multiply dst_stride_z by DEPTH_GEMM3D
3463    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3464
3465#else // defined(REINTERPRET_OUTPUT_AS_3D)
3466
3467    // Add offset for batched GEMM
3468    dst_addr += z * dst_stride_z;
3469
3470#endif // defined(REINTERPRET_OUTPUT_AS_3D)
3471
3472    // Multiply the accumulated matrix-matrix product by the weight alpha
3473#if defined(ALPHA)
3474    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
3475#endif // defined(ALPHA)
3476
3477    // Add beta*bias
3478#if defined(BETA)
3479#if defined(BROADCAST_BIAS)
3480    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
3481
3482    LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
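    // With BROADCAST_BIAS the bias is a single row of N values: the same N0-wide slice is added
    // to every one of the M0 output rows below.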
3483
3484#ifndef UNIT_BETA
3485    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
3486#endif // UNIT_BETA
3487
3488    // c = c + bias[broadcasted]
3489    ADD_BLOCK_BROADCAST(M0, c, bias0);
3490
3491#else // defined(BROADCAST_BIAS)
3492    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * bias_stride_y) + z * bias_stride_z;
3493
3494    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
3495
3496#ifndef UNIT_BETA
3497    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
3498#endif // UNIT_BETA
3499
3500    // c = c + bias
3501    ADD_BLOCK(M0, c, bias);
3502
3503#endif // defined(BROADCAST_BIAS)
3504#endif // defined(BETA)
3505
3506#if defined(ACTIVATION_TYPE)
3507    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, N0, c, A_VAL, B_VAL);
3508#endif // defined(ACTIVATION_TYPE)
3509
3510    const bool cond_y = y == 0;
3511    const bool cond_x = ((x + 1) * N0 >= N);
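    // cond_y marks the row block that may be partial (the y == 0 block, given how
    // COMPUTE_M0_START_ROW shifts the remaining blocks), cond_x the right-most block of columns;
    // STORE_BLOCK_BOUNDARY_AWARE uses them to clamp the store to PARTIAL_STORE_M0 x PARTIAL_STORE_N0
    // elements where needed.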
3512
3513    // Store output block
3514    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
3515}
3516#endif // defined(GEMM_MM_NATIVE)
3517#endif // defined(M0) && defined(N0) && defined(K0) && defined(DATA_TYPE)
3518
3519#if defined(BETA)
3520/** This OpenCL kernel performs an in-place matrix addition between two matrices, taking into account that the second matrix might be weighted by a scalar value beta.
3521 *
3522 * @note The beta value needs to be passed at compile time using -DBETA
3523 *
3524 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: F32
3525 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
3526 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
3527 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
3528 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
3529 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
3530 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
3531 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
3532 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
3533 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
3534 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
3535 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
3536 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
3537 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
3538 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
3539 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
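 *
 * @note Per 4-element vector this computes dst = dst + BETA * src. Illustrative example (values
 *       chosen only for exposition): with -DBETA=2.0f, src = {1, 2, 3, 4} and dst = {10, 10, 10, 10},
 *       the value written back to dst is {12, 14, 16, 18}.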
3540 */
3541__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
3542                          TENSOR3D_DECLARATION(dst))
3543{
3544    // Compute source and destination addresses
3545    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
3546    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
3547
3548    // Load values from A x B
3549    float4 alpha_ab = vload4(0, (__global float *)dst.ptr);
3550
3551    // Load values from Matrix C
3552    float4 c = vload4(0, (__global float *)src.ptr);
3553
3554    // Computes alpha * (A x B) + beta * C
3555    float4 out = alpha_ab + (float4)BETA * c;
3556
3557    // Store the final result back into the A x B (destination) matrix
3558    vstore4(out, 0, (__global float *)dst.ptr);
3559}
3560
3561#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
3562/** This OpenCL kernel performs an in-place matrix addition between two matrices, taking into account that the second matrix might be weighted by a scalar value beta.
3563 *
3564 * @note The beta value needs to be passed at compile time using -DBETA
3565 *
3566 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: F16
3567 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
3568 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
3569 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
3570 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
3571 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
3572 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
3573 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
3574 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
3575 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
3576 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
3577 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
3578 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
3579 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
3580 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
3581 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
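 *
 * @note Same computation as gemm_ma_f32 above, performed on 8 half-precision values per work-item.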
3582 */
3583__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
3584                          TENSOR3D_DECLARATION(dst))
3585{
3586    // Compute source and destination addresses
3587    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
3588    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
3589
3590    // Load values from A x B
3591    half8 alpha_ab = vload8(0, (__global half *)dst.ptr);
3592
3593    // Load values from Matrix C
3594    half8 c = vload8(0, (__global half *)src.ptr);
3595
3596    // Computes alpha * (A x B) + beta * C
3597    half8 out = alpha_ab + (half8)BETA * c;
3598
3599    // Store the final result back into the A x B (destination) matrix
3600    vstore8(out, 0, (__global half *)dst.ptr);
3601}
3602#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
3603#endif // defined(BETA)
3604