1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014-2015 Benoit Steiner <[email protected]>
5 // Copyright (C) 2015 Navdeep Jaitly <[email protected]>
6 // Copyright (C) 2014 Eric Martin <[email protected]>
7 //
8 // This Source Code Form is subject to the terms of the Mozilla
9 // Public License v. 2.0. If a copy of the MPL was not distributed
10 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
11 
12 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
13 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
14 
15 #if defined(EIGEN_USE_GPU) && defined(EIGEN_GPUCC)
16 
17 namespace Eigen {
18 
19 template<typename Scalar, typename Index, typename LhsMapper,
20          typename RhsMapper, typename OutputMapper, bool needs_edge_check>
21 __device__ EIGEN_STRONG_INLINE void
22 EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
23                                const OutputMapper output, Scalar* lhs_shmem, Scalar* rhs_shmem,
24                        const Index m_size, const Index n_size, const Index k_size) {
25 
26   const Index m_block_idx = blockIdx.x;
27   const Index n_block_idx = blockIdx.y;
28 
29   const Index base_m = 64 * m_block_idx;
30   const Index base_n = 64 * n_block_idx;
31 
32   // declare and initialize 64 registers for output 8x8 block
33 
34   // prefetch registers
35   Scalar lhs_pf0;
36   Scalar lhs_pf1;
37   Scalar lhs_pf2;
38   Scalar lhs_pf3;
39   Scalar lhs_pf4;
40   Scalar lhs_pf5;
41   Scalar lhs_pf6;
42   Scalar lhs_pf7;
43 
44   Scalar rhs_pf0;
45   Scalar rhs_pf1;
46   Scalar rhs_pf2;
47   Scalar rhs_pf3;
48   Scalar rhs_pf4;
49   Scalar rhs_pf5;
50   Scalar rhs_pf6;
51   Scalar rhs_pf7;
52 
53   // shared memory is formatted
54   // (contract idx in block, nocontract idx in block, block idx)
55   // where block idx is column major. This transposition limits the number of
56   // bank conflicts when reading the LHS. The core idea is that since the contracting
57   // index is shared by both sides, it should be in threadIdx.x.
58 
59   // On the LHS, we pad each row inside of each block with an extra element. This makes
60   // each block 8 rows of 9 elements, which is 72 elements. This gives no bank conflicts
61   // on writes and very few 2-way conflicts on reads. There is an 8x8 grid of these blocks.
62 
63   // On the RHS we just add 8 padding elements to the end of each block. This gives no bank
64   // conflicts on writes and also none on reads.
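  // Worked example of the bank behaviour described above (assuming a 4-byte Scalar
  // and the usual 32 shared-memory banks): within a warp threadIdx.z is fixed,
  // threadIdx.x runs over 0..7 and threadIdx.y over four consecutive values, so the
  // LHS store index is threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z. Modulo 32
  // this equals 8 * threadIdx.y + 9 * threadIdx.x + threadIdx.z, and because
  // 9 * threadIdx.x takes every residue modulo 8 exactly once, the 32 lanes land in
  // 32 distinct banks, i.e. the stores are conflict free.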
65 
66   // storage indices
67   const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
68   const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;
69 
70   const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
71   const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
72   const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
73   const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
74   const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
75   const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
76   const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
77   const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;
78 
79   const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
80   const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
81   const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
82   const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
83   const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
84   const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
85   const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
86   const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;
87 
88   // in the loading code, the following variables are important:
89   // threadIdx.x: the vertical position in an 8x8 block
90   // threadIdx.y: the vertical index of the 8x8 block in the grid
91   // threadIdx.z: the horizontal position in an 8x8 block
92   // k: the horizontal index of the 8x8 block in the grid
93   //
94   // The k parameter is implicit (it was the loop counter for a loop that went
95   // from 0 to <8), but now that loop is unrolled in the code below.
96 
97   const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
98   const Index lhs_vert = base_m + load_idx_vert;
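  // Illustrative trace of the indexing below: for the thread with
  // (threadIdx.x, threadIdx.y, threadIdx.z) = (2, 1, 3), load_idx_vert is 10, so
  // step q of the unrolled prefetch reads lhs(base_m + 10, base_k + 3 + 8 * q) into
  // the q-th prefetch register, and writeRegToShmem stores it at
  // lhs_shmem[93 + 576 * q], where 93 = 1 * 72 + 2 * 9 + 3.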
99 
100 #define prefetchIntoRegisters(base_k)                           \
101   {                                                             \
102     lhs_pf0 = conv(0);                                          \
103     lhs_pf1 = conv(0);                                          \
104     lhs_pf2 = conv(0);                                          \
105     lhs_pf3 = conv(0);                                          \
106     lhs_pf4 = conv(0);                                          \
107     lhs_pf5 = conv(0);                                          \
108     lhs_pf6 = conv(0);                                          \
109     lhs_pf7 = conv(0);                                          \
110                                                                 \
111     rhs_pf0 = conv(0);                                          \
112     rhs_pf1 = conv(0);                                          \
113     rhs_pf2 = conv(0);                                          \
114     rhs_pf3 = conv(0);                                          \
115     rhs_pf4 = conv(0);                                          \
116     rhs_pf5 = conv(0);                                          \
117     rhs_pf6 = conv(0);                                          \
118     rhs_pf7 = conv(0);                                          \
119                                                                 \
120     if (!needs_edge_check || lhs_vert < m_size) {               \
121       const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8;   \
122       const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8;   \
123       const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8;   \
124       const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8;   \
125       const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8;   \
126       const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8;   \
127       const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8;   \
128       const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8;   \
129                                                                 \
130       if (!needs_edge_check || lhs_horiz_7 < k_size) {          \
131         lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
132         lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
133         lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
134         lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
135         lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
136         lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
137         lhs_pf6 = lhs(lhs_vert, lhs_horiz_6);                   \
138         lhs_pf7 = lhs(lhs_vert, lhs_horiz_7);                   \
139       } else if (lhs_horiz_6 < k_size) {                        \
140         lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
141         lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
142         lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
143         lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
144         lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
145         lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
146         lhs_pf6 = lhs(lhs_vert, lhs_horiz_6);                   \
147       } else if (lhs_horiz_5 < k_size) {                        \
148         lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
149         lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
150         lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
151         lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
152         lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
153         lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
154       } else if (lhs_horiz_4 < k_size) {                        \
155         lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
156         lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
157         lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
158         lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
159         lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
160       } else if (lhs_horiz_3 < k_size) {                        \
161         lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
162         lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
163         lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
164         lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
165       } else if (lhs_horiz_2 < k_size) {                        \
166         lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
167         lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
168         lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
169       } else if (lhs_horiz_1 < k_size) {                        \
170         lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
171         lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
172       } else if (lhs_horiz_0 < k_size) {                        \
173         lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
174       }                                                         \
175     }                                                           \
176                                                                 \
177     const Index rhs_vert = base_k + load_idx_vert;              \
178     if (!needs_edge_check || rhs_vert < k_size) {               \
179       const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8;   \
180       const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8;   \
181       const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8;   \
182       const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8;   \
183       const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8;   \
184       const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8;   \
185       const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8;   \
186       const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8;   \
187                                                                 \
188       if (rhs_horiz_7 < n_size) {                               \
189         rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
190         rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
191         rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
192         rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
193         rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
194         rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                   \
195         rhs_pf6 = rhs(rhs_vert, rhs_horiz_6);                   \
196         rhs_pf7 = rhs(rhs_vert, rhs_horiz_7);                   \
197       } else if (rhs_horiz_6 < n_size) {                        \
198         rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
199         rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
200         rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
201         rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
202         rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
203         rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                   \
204         rhs_pf6 = rhs(rhs_vert, rhs_horiz_6);                   \
205       } else if (rhs_horiz_5 < n_size) {                        \
206         rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
207         rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
208         rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
209         rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
210         rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
211         rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                   \
212       } else if (rhs_horiz_4 < n_size) {                        \
213         rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
214         rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
215         rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
216         rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
217         rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
218       } else if (rhs_horiz_3 < n_size) {                        \
219         rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
220         rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
221         rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
222         rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
223       } else if (rhs_horiz_2 < n_size) {                        \
224         rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
225         rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
226         rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
227       } else if (rhs_horiz_1 < n_size) {                        \
228         rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
229         rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
230       } else if (rhs_horiz_0 < n_size) {                        \
231         rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
232       }                                                         \
233     }                                                           \
234   }                                                             \
235 
236 #define writeRegToShmem(_)                      \
237   lhs_shmem[lhs_store_idx_0] = lhs_pf0;         \
238   rhs_shmem[rhs_store_idx_0] = rhs_pf0;         \
239                                                 \
240   lhs_shmem[lhs_store_idx_1] = lhs_pf1;         \
241   rhs_shmem[rhs_store_idx_1] = rhs_pf1;         \
242                                                 \
243   lhs_shmem[lhs_store_idx_2] = lhs_pf2;         \
244   rhs_shmem[rhs_store_idx_2] = rhs_pf2;         \
245                                                 \
246   lhs_shmem[lhs_store_idx_3] = lhs_pf3;         \
247   rhs_shmem[rhs_store_idx_3] = rhs_pf3;         \
248                                                 \
249   lhs_shmem[lhs_store_idx_4] = lhs_pf4;         \
250   rhs_shmem[rhs_store_idx_4] = rhs_pf4;         \
251                                                 \
252   lhs_shmem[lhs_store_idx_5] = lhs_pf5;         \
253   rhs_shmem[rhs_store_idx_5] = rhs_pf5;         \
254                                                 \
255   lhs_shmem[lhs_store_idx_6] = lhs_pf6;         \
256   rhs_shmem[rhs_store_idx_6] = rhs_pf6;         \
257                                                 \
258   lhs_shmem[lhs_store_idx_7] = lhs_pf7;         \
259   rhs_shmem[rhs_store_idx_7] = rhs_pf7;         \
260 
261   // declare and initialize result array
262 #define res(i, j) _res_##i##j
263 #define initResultRow(i)                        \
264   Scalar res(i, 0) = conv(0);                   \
265   Scalar res(i, 1) = conv(0);                   \
266   Scalar res(i, 2) = conv(0);                   \
267   Scalar res(i, 3) = conv(0);                   \
268   Scalar res(i, 4) = conv(0);                   \
269   Scalar res(i, 5) = conv(0);                   \
270   Scalar res(i, 6) = conv(0);                   \
271   Scalar res(i, 7) = conv(0);                   \
272 
273   internal::scalar_cast_op<int, Scalar> conv;
274   initResultRow(0);
275   initResultRow(1);
276   initResultRow(2);
277   initResultRow(3);
278   initResultRow(4);
279   initResultRow(5);
280   initResultRow(6);
281   initResultRow(7);
282 #undef initResultRow
283 
284   for (Index base_k = 0; base_k < k_size; base_k += 64) {
285     // wait for previous iteration to finish with shmem. Despite common sense,
286     // the code is a bit faster with this here than at the bottom of the loop
287     __syncthreads();
288 
289     prefetchIntoRegisters(base_k);
290     writeRegToShmem();
291 
292     #undef prefetchIntoRegisters
293     #undef writeRegToShmem
294 
295     // wait for shared mem packing to be done before starting computation
296     __syncthreads();
297 
298     // compute 8x8 matrix product by outer product. This involves packing one column
299     // of LHS and one row of RHS into registers (takes 16 registers).
300 
301 #define lcol(i) _lcol##i
302     Scalar lcol(0);
303     Scalar lcol(1);
304     Scalar lcol(2);
305     Scalar lcol(3);
306     Scalar lcol(4);
307     Scalar lcol(5);
308     Scalar lcol(6);
309     Scalar lcol(7);
310 
311 #define rrow(j) _rrow##j
312     Scalar rrow(0);
313     Scalar rrow(1);
314     Scalar rrow(2);
315     Scalar rrow(3);
316     Scalar rrow(4);
317     Scalar rrow(5);
318     Scalar rrow(6);
319     Scalar rrow(7);
320 
321     // Now x corresponds to k, y to m, and z to n
322     const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
323     const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];
324 
325 #define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
326 #define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]
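    // Putting the two macros together (illustrative): in pass p below, the thread
    // with indices (x, y, z) accumulates
    //   res(i, j) += lhs(base_m + 8*i + y, base_k + 8*p + x)
    //              * rhs(base_k + 8*p + x, base_n + 8*j + z)
    // so after the k loop each thread holds the partial sums for the contraction
    // indices congruent to threadIdx.x modulo 8.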
327 
328 #define loadData(i, j)                          \
329     lcol(0) = lhs_element(0, j);               \
330     rrow(0) = rhs_element(i, 0);               \
331     lcol(1) = lhs_element(1, j);               \
332     rrow(1) = rhs_element(i, 1);               \
333     lcol(2) = lhs_element(2, j);               \
334     rrow(2) = rhs_element(i, 2);               \
335     lcol(3) = lhs_element(3, j);               \
336     rrow(3) = rhs_element(i, 3);               \
337     lcol(4) = lhs_element(4, j);               \
338     rrow(4) = rhs_element(i, 4);               \
339     lcol(5) = lhs_element(5, j);               \
340     rrow(5) = rhs_element(i, 5);               \
341     lcol(6) = lhs_element(6, j);               \
342     rrow(6) = rhs_element(i, 6);               \
343     lcol(7) = lhs_element(7, j);               \
344     rrow(7) = rhs_element(i, 7);               \
345 
346 #define computeCol(j)                           \
347     res(0, j) += lcol(0) * rrow(j);             \
348     res(1, j) += lcol(1) * rrow(j);             \
349     res(2, j) += lcol(2) * rrow(j);             \
350     res(3, j) += lcol(3) * rrow(j);             \
351     res(4, j) += lcol(4) * rrow(j);             \
352     res(5, j) += lcol(5) * rrow(j);             \
353     res(6, j) += lcol(6) * rrow(j);             \
354     res(7, j) += lcol(7) * rrow(j);             \
355 
356 #define computePass(i)                          \
357     loadData(i, i);                             \
358                                                 \
359     computeCol(0);                              \
360     computeCol(1);                              \
361     computeCol(2);                              \
362     computeCol(3);                              \
363     computeCol(4);                              \
364     computeCol(5);                              \
365     computeCol(6);                              \
366     computeCol(7);                              \
367 
368     computePass(0);
369     computePass(1);
370     computePass(2);
371     computePass(3);
372     computePass(4);
373     computePass(5);
374     computePass(6);
375     computePass(7);
376 
377 #undef lcol
378 #undef rrow
379 #undef lhs_element
380 #undef rhs_element
381 #undef loadData
382 #undef computeCol
383 #undef computePass
384   } // end loop over k
385 
386   // we've now iterated over all of the large (ie width 64) k blocks and
387   // accumulated results in registers. At this point thread (x, y, z) contains
388   // the sum across all big k blocks of the product of the little k block of index (x, y)
389   // with the block of index (x, z). To compute the final output, we need to reduce
390   // over the 8 values of threadIdx.x by summation.
391 #if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
392 #define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)
393 #else
394 #define shuffleInc(i, j, mask) res(i, j) += __shfl_xor_sync(0xFFFFFFFF, res(i, j), mask)
395 #endif
396 
397 #define reduceRow(i, mask)                      \
398   shuffleInc(i, 0, mask);                       \
399   shuffleInc(i, 1, mask);                       \
400   shuffleInc(i, 2, mask);                       \
401   shuffleInc(i, 3, mask);                       \
402   shuffleInc(i, 4, mask);                       \
403   shuffleInc(i, 5, mask);                       \
404   shuffleInc(i, 6, mask);                       \
405   shuffleInc(i, 7, mask);                       \
406 
407 #define reduceMatrix(mask)                      \
408   reduceRow(0, mask);                           \
409   reduceRow(1, mask);                           \
410   reduceRow(2, mask);                           \
411   reduceRow(3, mask);                           \
412   reduceRow(4, mask);                           \
413   reduceRow(5, mask);                           \
414   reduceRow(6, mask);                           \
415   reduceRow(7, mask);                           \
416 
417   // actually perform the reduction; now each thread of index (_, y, z)
418   // contains the correct values in its registers that belong in the output
419   // block
420   reduceMatrix(1);
421   reduceMatrix(2);
422   reduceMatrix(4);
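  // (The xor masks 1, 2 and 4 flip the three low bits of the lane id, which for an
  // (8, 8, 8) block are exactly threadIdx.x, so this butterfly sums the 8 partial
  // results that differ only in threadIdx.x.)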
423 
424 #undef shuffleInc
425 #undef reduceRow
426 #undef reduceMatrix
427 
428   // now we need to copy the 64 values into main memory. We can't split work
429   // among threads because all variables are in registers. There are two ways
430   // to do this:
431   // (1) have 1 thread do 64 writes from registers into global memory
432   // (2) have 1 thread do 64 writes into shared memory, and then 8 threads
433   //     each do 8 writes into global memory. We can just overwrite the shared
434   //     memory from the problem we just solved.
435   // (2) is slightly faster than (1) due to less branching and more ILP
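  // Concretely (illustrative): after the shuffle reduction, res(i, j) on the thread
  // with threadIdx.x == 0 of a given (y, z) is the finished output entry at row
  // base_m + 8*i + threadIdx.y and column base_n + 8*j + threadIdx.z; the staging
  // through lhs_shmem below just redistributes those 64 values so that the thread
  // with threadIdx.x == i issues the global store for row offset 8*i.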
436 
437   // TODO: won't yield much gain, but could just use currently unused shared mem
438   //       and then we won't have to sync
439   // wait for shared mem to be out of use
440   __syncthreads();
441 
442 #define writeResultShmem(i, j)                                          \
443   lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j); \
444 
445 #define writeRow(i)                             \
446   writeResultShmem(i, 0);                       \
447   writeResultShmem(i, 1);                       \
448   writeResultShmem(i, 2);                       \
449   writeResultShmem(i, 3);                       \
450   writeResultShmem(i, 4);                       \
451   writeResultShmem(i, 5);                       \
452   writeResultShmem(i, 6);                       \
453   writeResultShmem(i, 7);                       \
454 
455   if (threadIdx.x == 0) {
456     writeRow(0);
457     writeRow(1);
458     writeRow(2);
459     writeRow(3);
460     writeRow(4);
461     writeRow(5);
462     writeRow(6);
463     writeRow(7);
464   }
465 #undef writeResultShmem
466 #undef writeRow
467 
468   const int max_i_write = numext::mini((int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
469   const int max_j_write = numext::mini((int)((n_size - base_n - threadIdx.z + 7) / 8), 8);
470 
471   if (threadIdx.x < max_i_write) {
472     if (max_j_write == 8) {
473       // TODO: can i trade bank conflicts for coalesced writes?
474       Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
475       Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
476       Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
477       Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
478       Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
479       Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
480       Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
481       Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];
482 
483       output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
484       output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
485       output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
486       output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
487       output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
488       output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
489       output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
490       output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
491     } else {
492 #pragma unroll 7
493       for (int j = 0; j < max_j_write; j++) {
494         Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
495         output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
496       }
497     }
498   }
499 #undef res
500 }
501 
502 
503 template<typename Scalar, typename Index, typename LhsMapper,
504          typename RhsMapper, typename OutputMapper>
505 __global__ void
506 #if defined(EIGEN_HIPCC)
507 __launch_bounds__(512, 1)
508 #else
509 __launch_bounds__(512)
510 #endif
511 EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
512                        const OutputMapper output,
513                        const Index m_size, const Index n_size, const Index k_size) {
514   __shared__ Scalar lhs_shmem[72 * 64];
515   __shared__ Scalar rhs_shmem[72 * 64];
516 
517   const Index m_block_idx = blockIdx.x;
518   const Index n_block_idx = blockIdx.y;
519 
520   const Index base_m = 64 * m_block_idx;
521   const Index base_n = 64 * n_block_idx;
522 
523   if (base_m + 63 < m_size && base_n + 63 < n_size) {
524     EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
525   } else {
526     EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
527   }
528 }
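// A minimal launch sketch for the kernel above (illustrative only; these are the
// grid/block shapes the indexing inside the kernel assumes, one 512-thread (8, 8, 8)
// block per 64x64 output tile):
//
//   dim3 grid((m_size + 63) / 64, (n_size + 63) / 64, 1);
//   dim3 block(8, 8, 8);
//   EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>
//       <<<grid, block>>>(lhs, rhs, output, m_size, n_size, k_size);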
529 
530 
531 template<typename Index, typename LhsMapper,
532          typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
533          bool CHECK_RHS_BOUNDARY>
534 __device__ __forceinline__ void
535 EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs,
536                        const OutputMapper output, float2 lhs_shmem2[][16],
537                        float2 rhs_shmem2[][8], const Index m_size,
538                        const Index n_size, const Index k_size,
539                        const Index base_m, const Index base_n) {
540 
541   // prefetch registers
542   float4 lhs_pf0, rhs_pf0;
543 
544   float4 results[4];
545   for (int i=0; i < 4; i++) {
546     results[i].x = results[i].y = results[i].z = results[i].w = 0;
547   }
548 
549 #define prefetch_lhs(reg, row, col)                            \
550     if (!CHECK_LHS_BOUNDARY) {                                 \
551       if (col < k_size) {                                      \
552         reg =lhs.template loadPacket<float4,Unaligned>(row, col);     \
553       }                                                        \
554     } else {                                                   \
555       if (col < k_size) {                                      \
556         if (row + 3 < m_size) {                                \
557           reg =lhs.template loadPacket<float4,Unaligned>(row, col);   \
558         } else if (row + 2 < m_size) {                         \
559           reg.x =lhs(row + 0, col);                            \
560           reg.y =lhs(row + 1, col);                            \
561           reg.z =lhs(row + 2, col);                            \
562         } else if (row + 1 < m_size) {                         \
563           reg.x =lhs(row + 0, col);                            \
564           reg.y =lhs(row + 1, col);                            \
565         } else if (row  < m_size) {                            \
566           reg.x =lhs(row + 0, col);                            \
567         }                                                      \
568       }                                                        \
569     }							       \
570 
571   Index lhs_vert = base_m+threadIdx.x*4;
572 
573   for (Index k = 0; k < k_size; k += 16) {
574 
575     lhs_pf0 = internal::pset1<float4>(0);
576     rhs_pf0 = internal::pset1<float4>(0);
577 
578     Index lhs_horiz = threadIdx.y+k;
579     prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)
580 
581     Index rhs_vert = k+(threadIdx.x%4)*4;
582     Index rhs_horiz0 = (threadIdx.x>>2)+threadIdx.y*4+base_n;
583 
584     if (!CHECK_RHS_BOUNDARY) {
585       if ((rhs_vert + 3) < k_size) {
586         // just CHECK_RHS_BOUNDARY
587         rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
588       } else if (rhs_vert + 2 < k_size) {
589         // just CHECK_RHS_BOUNDARY
590         rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
591         rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
592         rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
593       } else if (rhs_vert + 1 < k_size) {
594         rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
595         rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
596       } else if (rhs_vert  < k_size) {
597         rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
598       }
599     } else {
600       if (rhs_horiz0 < n_size) {
601         if ((rhs_vert + 3) < k_size) {
602           rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
603         } else if ((rhs_vert + 2) < k_size) {
604           rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
605           rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
606           rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
607         } else if ((rhs_vert + 1) < k_size) {
608           rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
609           rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
610         } else if (rhs_vert  < k_size) {
611           rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
612         }
613       }
614     }
615     float x1, x2 ;
616     // the following can be a bitwise operation..... some day.
617     if((threadIdx.x%8) < 4) {
618       x1 = rhs_pf0.y;
619       x2 = rhs_pf0.w;
620     } else {
621       x1 = rhs_pf0.x;
622       x2 = rhs_pf0.z;
623     }
624     #if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
625     x1 = __shfl_xor(x1, 4);
626     x2 = __shfl_xor(x2, 4);
627     #else
628     x1 = __shfl_xor_sync(0xFFFFFFFF, x1, 4);
629     x2 = __shfl_xor_sync(0xFFFFFFFF, x2, 4);
630     #endif
631     if((threadIdx.x%8) < 4) {
632       rhs_pf0.y = x1;
633       rhs_pf0.w = x2;
634     } else {
635       rhs_pf0.x = x1;
636       rhs_pf0.z = x2;
637     }
638 
639     // We have 64 features.
640     // Row 0 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 0, 1.
641     // Row 1 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 2, 3.
642     // ...
643     // Row 31 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 62, 63
644     // Row 32 -> times (2, 6, 10, 14, 3, 7, 11, 15) for features 0, 1
645     // ...
646     rhs_shmem2[(threadIdx.x>>3)+ threadIdx.y*2][threadIdx.x%8] = make_float2(rhs_pf0.x, rhs_pf0.y);
647     rhs_shmem2[(threadIdx.x>>3)+ threadIdx.y*2+32][threadIdx.x%8] = make_float2(rhs_pf0.z, rhs_pf0.w);
648 
649     // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
650     // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
651     // ...
652     // Row 15 (time 15) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
653     // Row 16 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), ..  (62, 63)
654     // ...
655 
656     lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y);
657     lhs_shmem2[threadIdx.y+16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w);
658 
659 
660 #define add_vals(fl1, fl2, fr1, fr2)\
661     results[0].x += fl1.x * fr1.x;\
662     results[0].y += fl1.y * fr1.x;\
663     results[0].z += fl2.x * fr1.x;\
664     results[0].w += fl2.y * fr1.x;\
665 \
666     results[1].x += fl1.x * fr1.y;\
667     results[1].y += fl1.y * fr1.y;\
668     results[1].z += fl2.x * fr1.y;\
669     results[1].w += fl2.y * fr1.y;\
670 \
671     results[2].x += fl1.x * fr2.x;\
672     results[2].y += fl1.y * fr2.x;\
673     results[2].z += fl2.x * fr2.x;\
674     results[2].w += fl2.y * fr2.x;\
675 \
676     results[3].x += fl1.x * fr2.y;\
677     results[3].y += fl1.y * fr2.y;\
678     results[3].z += fl2.x * fr2.y;\
679     results[3].w += fl2.y * fr2.y;\
680 
681     __syncthreads();
682 
683     // Do the multiplies.
684     #pragma unroll
685     for (int koff = 0; koff < 16; koff ++) {
686       // 32 x threads.
687       float2 fl1 = lhs_shmem2[koff][threadIdx.x];
688       float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];
689 
690       int start_feature = threadIdx.y * 4;
691       float2 fr1 = rhs_shmem2[(start_feature>>1) + 32*((koff%4)/2)][koff/4 + (koff%2)*4];
692       float2 fr2 = rhs_shmem2[(start_feature>>1) + 1 + 32*((koff%4)/2)][koff/4 + (koff%2)*4];
693 
694       add_vals(fl1, fl2, fr1, fr2)
695     }
696     __syncthreads();
697   }
698 
699 #undef prefetch_lhs
700 #undef add_vals
701 
702   Index horiz_base = threadIdx.y*4+base_n;
703   if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
704     for (int i = 0; i < 4; i++) {
705       output(lhs_vert, horiz_base + i) = results[i].x;
706       output(lhs_vert + 1, horiz_base + i) = results[i].y;
707       output(lhs_vert + 2, horiz_base + i) = results[i].z;
708       output(lhs_vert + 3, horiz_base + i) = results[i].w;
709     }
710   } else if (!CHECK_RHS_BOUNDARY) {
711     // CHECK LHS
712     if (lhs_vert + 3 < m_size) {
713       for (int i = 0; i < 4; i++) {
714         output(lhs_vert, horiz_base + i) = results[i].x;
715         output(lhs_vert + 1, horiz_base + i) = results[i].y;
716         output(lhs_vert + 2, horiz_base + i) = results[i].z;
717         output(lhs_vert + 3, horiz_base + i) = results[i].w;
718       }
719     } else if (lhs_vert + 2 < m_size) {
720       for (int i = 0; i < 4; i++) {
721         output(lhs_vert, horiz_base + i) = results[i].x;
722         output(lhs_vert + 1, horiz_base + i) = results[i].y;
723         output(lhs_vert + 2, horiz_base + i) = results[i].z;
724       }
725     } else if (lhs_vert + 1 < m_size) {
726       for (int i = 0; i < 4; i++) {
727         output(lhs_vert, horiz_base + i) = results[i].x;
728         output(lhs_vert + 1, horiz_base + i) = results[i].y;
729       }
730     } else if (lhs_vert  < m_size) {
731       for (int i = 0; i < 4; i++) {
732         output(lhs_vert, horiz_base + i) = results[i].x;
733       }
734     }
735   } else if (!CHECK_LHS_BOUNDARY) {
736     // CHECK RHS
737     /*
738     int ncols_rem = fminf(n_size- horiz_base, 4);
739     for (int i = 0; i < ncols_rem; i++) {
740       output(lhs_vert, horiz_base + i) = results[i].x;
741       output(lhs_vert + 1, horiz_base + i) = results[i].y;
742       output(lhs_vert + 2, horiz_base + i) = results[i].z;
743       output(lhs_vert + 3, horiz_base + i) = results[i].w;
744     }*/
745     for (int i = 0; i < 4; i++) {
746       if (horiz_base+i < n_size) {
747         output(lhs_vert, horiz_base + i) = results[i].x;
748         output(lhs_vert + 1, horiz_base + i) = results[i].y;
749         output(lhs_vert + 2, horiz_base + i) = results[i].z;
750         output(lhs_vert + 3, horiz_base + i) = results[i].w;
751        }
752     }
753   } else {
754     // CHECK both boundaries.
755     for (int i = 0; i < 4; i++) {
756       if (horiz_base+i < n_size) {
757         if (lhs_vert < m_size)
758           output(lhs_vert, horiz_base + i) = results[i].x;
759         if (lhs_vert + 1 < m_size)
760           output(lhs_vert + 1, horiz_base + i) = results[i].y;
761         if (lhs_vert + 2 < m_size)
762           output(lhs_vert + 2, horiz_base + i) = results[i].z;
763         if (lhs_vert + 3 < m_size)
764           output(lhs_vert + 3, horiz_base + i) = results[i].w;
765       }
766     }
767   }
768 }
769 
770 
771 template<typename Index, typename LhsMapper,
772          typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
773          bool CHECK_RHS_BOUNDARY>
774 __device__ __forceinline__ void
775 EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
776                        const OutputMapper output, float2 lhs_shmem2[][32],
777                        float2 rhs_shmem2[][8], const Index m_size,
778                        const Index n_size, const Index k_size,
779                        const Index base_m, const Index base_n) {
780 
781   // prefetch registers
782   float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3;
783   float4 rhs_pf0, rhs_pf1;
784 
785   float4 results[8];
786   for (int i=0; i < 8; i++) {
787     results[i].x = results[i].y = results[i].z = results[i].w = 0;
788   }
789 
790   Index lhs_vert = base_m+threadIdx.x*4+(threadIdx.y%4)*32;
791   for (Index k = 0; k < k_size; k += 32) {
792     lhs_pf0 = internal::pset1<float4>(0);
793     lhs_pf1 = internal::pset1<float4>(0);
794     lhs_pf2 = internal::pset1<float4>(0);
795     lhs_pf3 = internal::pset1<float4>(0);
796 
797     rhs_pf0 = internal::pset1<float4>(0);
798     rhs_pf1 = internal::pset1<float4>(0);
799 
800      if (!CHECK_LHS_BOUNDARY) {
801       if ((threadIdx.y/4+k+24) < k_size) {
802         lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
803         lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
804         lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
805         lhs_pf3 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
806       } else if ((threadIdx.y/4+k+16) < k_size) {
807         lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
808         lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
809         lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
810       } else if ((threadIdx.y/4+k+8) < k_size) {
811         lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
812         lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
813       } else if ((threadIdx.y/4+k) < k_size) {
814         lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
815       }
816     } else {
817       // just CHECK_LHS_BOUNDARY
818       if (lhs_vert + 3 < m_size) {
819         if ((threadIdx.y/4+k+24) < k_size) {
820           lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
821           lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
822           lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
823           lhs_pf3 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
824         } else if ((threadIdx.y/4+k+16) < k_size) {
825           lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
826           lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
827           lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
828         } else if ((threadIdx.y/4+k+8) < k_size) {
829           lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
830           lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
831         } else if ((threadIdx.y/4+k) < k_size) {
832           lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
833         }
834       } else if (lhs_vert + 2 < m_size) {
835         if ((threadIdx.y/4+k+24) < k_size) {
836           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
837           lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
838           lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k));
839           lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
840           lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
841           lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
842           lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
843           lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
844           lhs_pf2.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
845           lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
846           lhs_pf3.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
847           lhs_pf3.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+24));
848         } else if ((threadIdx.y/4+k+16) < k_size) {
849           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
850           lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
851           lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k));
852           lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
853           lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
854           lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
855           lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
856           lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
857           lhs_pf2.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
858         } else if ((threadIdx.y/4+k+8) < k_size) {
859           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
860           lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
861           lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k));
862           lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
863           lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
864           lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
865         } else if ((threadIdx.y/4+k) < k_size) {
866           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
867           lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
868           lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k));
869         }
870       } else if (lhs_vert + 1 < m_size) {
871         if ((threadIdx.y/4+k+24) < k_size) {
872           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
873           lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
874           lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
875           lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
876           lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
877           lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
878           lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
879           lhs_pf3.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
880         } else if ((threadIdx.y/4+k+16) < k_size) {
881           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
882           lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
883           lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
884           lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
885           lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
886           lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
887         } else if ((threadIdx.y/4+k+8) < k_size) {
888           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
889           lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
890           lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
891           lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
892         } else if ((threadIdx.y/4+k) < k_size) {
893           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
894           lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
895         }
896       } else if (lhs_vert < m_size) {
897         if ((threadIdx.y/4+k+24) < k_size) {
898           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
899           lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
900           lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
901           lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
902         } else if ((threadIdx.y/4+k+16) < k_size) {
903           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
904           lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
905           lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
906         } else if ((threadIdx.y/4+k+8) < k_size) {
907           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
908           lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
909         } else if ((threadIdx.y/4+k) < k_size) {
910           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
911         }
912       }
913     }
914     __syncthreads();
915     Index rhs_vert = k+threadIdx.x*4;
916     Index rhs_horiz0 = threadIdx.y*2+base_n;
917     Index rhs_horiz1 = threadIdx.y*2+1+base_n;
918     if (!CHECK_RHS_BOUNDARY) {
919       if ((rhs_vert + 3) < k_size) {
920         // just CHECK_RHS_BOUNDARY
921         rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
922         rhs_pf1 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz1);
923       } else if (rhs_vert + 2 < k_size) {
924         // just CHECK_RHS_BOUNDARY
925         rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
926         rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
927         rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
928         rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
929         rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
930         rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
931       } else if (rhs_vert + 1 < k_size) {
932         rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
933         rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
934         rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
935         rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
936       } else if (rhs_vert  < k_size) {
937         rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
938         rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
939       }
940     } else {
941       if (rhs_horiz1 < n_size) {
942         if ((rhs_vert + 3) < k_size) {
943           // just CHECK_RHS_BOUNDARY
944           rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
945           rhs_pf1 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz1);
946         } else if (rhs_vert + 2 < k_size) {
947           // just CHECK_RHS_BOUNDARY
948           rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
949           rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
950           rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
951           rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
952           rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
953           rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
954         } else if (k+threadIdx.x*4 + 1 < k_size) {
955           rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
956           rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
957           rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
958           rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
959         } else if (k+threadIdx.x*4  < k_size) {
960           rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
961           rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
962         }
963       } else if (rhs_horiz0 < n_size) {
964         if ((rhs_vert + 3) < k_size) {
965           // just CHECK_RHS_BOUNDARY
966           rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
967         } else if ((rhs_vert + 2) < k_size) {
968           // just CHECK_RHS_BOUNDARY
969           rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
970           rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
971           rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
972         } else if ((rhs_vert + 1) < k_size) {
973           rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
974           rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
975         } else if (rhs_vert  < k_size) {
976           rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
977         }
978       }
979     }
980     __syncthreads();
981     // Loaded. Do computation
982     // Row 0 -> times (0, 4, 8, .. 28) for features 0, 1.
983     // Row 1 -> times (0, 4, 8, .. 28) for features 2, 3.
984     // ..
985     // Row 31 -> times (0, 4, 8, .. 28) for features 62, 63
986     rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x);
987     // Row 32 -> times (1, 5, 9, .. 29) for features 0, 1.
988     // Row 33 -> times (1, 5, 9, .. 29) for features 2, 3.
989     // ..
990     rhs_shmem2[threadIdx.y+32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y);
991     // Row 64 -> times (2, 6, 10, .. 30) for features 0, 1.
992     // Row 65 -> times (2, 6, 10, .. 30) for features 2, 3.
993     rhs_shmem2[threadIdx.y+64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z);
994     // Row 96 -> times (3, 7, 11, .. 31) for features 0, 1.
995     // Row 97 -> times (3, 7, 11, .. 31) for features 2, 3.
996     rhs_shmem2[threadIdx.y+96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w);
997 
998     // LHS.
999     // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61) .. (124, 125)
1000     // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61) .. (124, 125)
1001     // ...
1002     // Row 8 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), ..  (62, 63) .. (126, 127)
1003     // Row 15 (time 7) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), ..  (62, 63) .. (126, 127)
1004 
1005 
1006 #define add_vals(a_feat1, a_feat2, f1, f2, f3, f4)\
1007       results[0].x += a_feat1.x * f1.x;\
1008       results[1].x += a_feat1.x * f1.y;\
1009       results[2].x += a_feat1.x * f2.x;\
1010       results[3].x += a_feat1.x * f2.y;\
1011       results[4].x += a_feat1.x * f3.x;\
1012       results[5].x += a_feat1.x * f3.y;\
1013       results[6].x += a_feat1.x * f4.x;\
1014       results[7].x += a_feat1.x * f4.y;\
1015 \
1016       results[0].y += a_feat1.y * f1.x;\
1017       results[1].y += a_feat1.y * f1.y;\
1018       results[2].y += a_feat1.y * f2.x;\
1019       results[3].y += a_feat1.y * f2.y;\
1020       results[4].y += a_feat1.y * f3.x;\
1021       results[5].y += a_feat1.y * f3.y;\
1022       results[6].y += a_feat1.y * f4.x;\
1023       results[7].y += a_feat1.y * f4.y;\
1024 \
1025       results[0].z += a_feat2.x * f1.x;\
1026       results[1].z += a_feat2.x * f1.y;\
1027       results[2].z += a_feat2.x * f2.x;\
1028       results[3].z += a_feat2.x * f2.y;\
1029       results[4].z += a_feat2.x * f3.x;\
1030       results[5].z += a_feat2.x * f3.y;\
1031       results[6].z += a_feat2.x * f4.x;\
1032       results[7].z += a_feat2.x * f4.y;\
1033 \
1034       results[0].w += a_feat2.y * f1.x;\
1035       results[1].w += a_feat2.y * f1.y;\
1036       results[2].w += a_feat2.y * f2.x;\
1037       results[3].w += a_feat2.y * f2.y;\
1038       results[4].w += a_feat2.y * f3.x;\
1039       results[5].w += a_feat2.y * f3.y;\
1040       results[6].w += a_feat2.y * f4.x;\
1041       results[7].w += a_feat2.y * f4.y;\
1042 
1043     lhs_shmem2[threadIdx.y/4][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.x, lhs_pf0.y);
1044     lhs_shmem2[threadIdx.y/4+8][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.x, lhs_pf1.y);
1045     lhs_shmem2[threadIdx.y/4+16][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.x, lhs_pf2.y);
1046     lhs_shmem2[threadIdx.y/4+24][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.x, lhs_pf3.y);
1047 
1048     lhs_shmem2[threadIdx.y/4 + 32][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.z, lhs_pf0.w);
1049     lhs_shmem2[threadIdx.y/4 + 40][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.z, lhs_pf1.w);
1050     lhs_shmem2[threadIdx.y/4 + 48][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.z, lhs_pf2.w);
1051     lhs_shmem2[threadIdx.y/4 + 56][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.z, lhs_pf3.w);
1052 
1053     __syncthreads();
1054 
1055     // Do the multiplies.
1056     #pragma unroll
1057     for (int koff = 0; koff < 32; koff ++) {
1058       float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8];
1059       float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8];
1060 
1061       // first feature is at (threadIdx.y/4) * 8; last is at start + 8.
1062       int start_feature = (threadIdx.y / 4) * 8;
1063 
1064       float2 br1 = rhs_shmem2[start_feature/2 +     (koff % 4) * 32][koff/4];
1065       float2 br2 = rhs_shmem2[start_feature/2 + 1 + (koff % 4) * 32][koff/4];
1066       float2 br3 = rhs_shmem2[start_feature/2 + 2 + (koff % 4) * 32][koff/4];
1067       float2 br4 = rhs_shmem2[start_feature/2 + 3 + (koff % 4) * 32][koff/4];
1068 
1069       add_vals(a3, a4, br1, br2, br3, br4)
1070     }
1071     __syncthreads();
1072   } // end loop over k
1073 
1074   __syncthreads();
1075   Index horiz_base = (threadIdx.y/4)*8+base_n;
1076   if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
1077     for (int i = 0; i < 8; i++) {
1078       output(lhs_vert, horiz_base + i) = results[i].x;
1079       output(lhs_vert + 1, horiz_base + i) = results[i].y;
1080       output(lhs_vert + 2, horiz_base + i) = results[i].z;
1081       output(lhs_vert + 3, horiz_base + i) = results[i].w;
1082     }
1083   } else if (!CHECK_RHS_BOUNDARY) {
1084     if (lhs_vert + 3 < m_size) {
1085       for (int i = 0; i < 8; i++) {
1086         output(lhs_vert, horiz_base + i) = results[i].x;
1087         output(lhs_vert + 1, horiz_base + i) = results[i].y;
1088         output(lhs_vert + 2, horiz_base + i) = results[i].z;
1089         output(lhs_vert + 3, horiz_base + i) = results[i].w;
1090       }
1091     } else if (lhs_vert + 2 < m_size) {
1092       for (int i = 0; i < 8; i++) {
1093         output(lhs_vert, horiz_base + i) = results[i].x;
1094         output(lhs_vert + 1, horiz_base + i) = results[i].y;
1095         output(lhs_vert + 2, horiz_base + i) = results[i].z;
1096       }
1097     } else if (lhs_vert + 1 < m_size) {
1098       for (int i = 0; i < 8; i++) {
1099         output(lhs_vert, horiz_base + i) = results[i].x;
1100         output(lhs_vert + 1, horiz_base + i) = results[i].y;
1101       }
1102     } else if (lhs_vert  < m_size) {
1103       for (int i = 0; i < 8; i++) {
1104         output(lhs_vert, horiz_base + i) = results[i].x;
1105       }
1106     }
1107   } else if (!CHECK_LHS_BOUNDARY) {
1108     // CHECK BOUNDARY_B
1109     for (int i = 0; i < 8; i++) {
1110       if (horiz_base + i < n_size) {
1111         output(lhs_vert, horiz_base + i) = results[i].x;
1112         output(lhs_vert + 1, horiz_base + i) = results[i].y;
1113         output(lhs_vert + 2, horiz_base + i) = results[i].z;
1114         output(lhs_vert + 3, horiz_base + i) = results[i].w;
1115       }
1116     }
1117   } else {
1118     // CHECK both boundaries.
1119     for (int i = 0; i < 8; i++) {
1120       if (horiz_base + i < n_size) {
1121         if (lhs_vert < m_size)
1122           output(lhs_vert, horiz_base + i) = results[i].x;
1123         if (lhs_vert + 1 < m_size)
1124           output(lhs_vert + 1, horiz_base + i) = results[i].y;
1125         if (lhs_vert + 2 < m_size)
1126           output(lhs_vert + 2, horiz_base + i) = results[i].z;
1127         if (lhs_vert + 3 < m_size)
1128           output(lhs_vert + 3, horiz_base + i) = results[i].w;
1129       }
1130     }
1131   }
1132 }
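
// Editor's note (a hedged sketch, not part of the original source): each thread of the
// kernel above owns eight float4 accumulators (`results`), i.e. a 4-row x 8-column tile
// of the output. The column base is
//   horiz_base = (threadIdx.y / 4) * 8 + base_n
// so, for example, a thread with threadIdx.y == 10 writes columns base_n + 16 .. base_n + 23
// (10 / 4 == 2, 2 * 8 == 16) and rows lhs_vert .. lhs_vert + 3. With an 8x32 block that is
// 256 threads * 32 values = one full 128x64 output tile; the CHECK_LHS_BOUNDARY /
// CHECK_RHS_BOUNDARY specializations only add row/column guards for tiles that overlap
// the matrix edge.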
1133 
1134 
1135 template<typename Index, typename LhsMapper,
1136          typename RhsMapper, typename OutputMapper>
1137 __global__ void
1138 #if defined(EIGEN_HIPCC)
1139 __launch_bounds__(256, 1)
1140 #else
1141 __launch_bounds__(256)
1142 #endif
1143 EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
1144                        const OutputMapper output,
1145                        const Index m_size, const Index n_size, const Index k_size) {
1146   __shared__ float2 lhs_shmem[64*32];
1147   __shared__ float2 rhs_shmem[128*8];
1148 
1149   typedef float2 LHS_MEM[64][32];
1150   typedef float2 RHS_MEM[128][8];
1151 
1152   const Index m_block_idx = blockIdx.x;
1153   const Index n_block_idx = blockIdx.y;
1154 
1155   const Index base_m = 128 * m_block_idx;
1156   const Index base_n = 64 * n_block_idx;
1157 
1158   bool check_rhs = (base_n + 63) >= n_size;
1159   bool check_lhs128 = (base_m + 127) >= m_size;
1160 
1161   if (!check_rhs) {
1162     if (!check_lhs128) {
1163       // >= 128 rows left
1164       EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
1165                      lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
1166     } else {
1167       EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
1168                      lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
1169     }
1170   } else {
1171     if (!check_lhs128) {
1172       // >= 128 rows left
1173       EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
1174                      lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
1175     } else {
1176       EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
1177                      lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
1178     }
1179   }
1180 }
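
// Editor's note (arithmetic worked out from the declarations above; sizeof(float2) == 8):
//   lhs_shmem: 64 * 32 float2 = 2048 * 8 B = 16 KiB
//   rhs_shmem: 128 * 8 float2 = 1024 * 8 B =  8 KiB   -> 24 KiB of shared memory per block.
// The flat __shared__ arrays are reinterpreted as 2-D views purely for indexing
// convenience. A minimal sketch of the same idiom, with illustrative names:
//
//   __shared__ float2 buf[64 * 32];
//   typedef float2 View[64][32];
//   View& view = *reinterpret_cast<View*>(buf);   // view[i][j] aliases buf[i * 32 + j]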
1181 
1182 template<typename Index, typename LhsMapper,
1183          typename RhsMapper, typename OutputMapper>
1184 __global__ void
1185 #if defined(EIGEN_HIPCC)
1186 __launch_bounds__(256, 1)
1187 #else
1188 __launch_bounds__(256)
1189 #endif
1190 EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs,
1191                        const OutputMapper output,
1192                        const Index m_size, const Index n_size, const Index k_size) {
1193   __shared__ float2 lhs_shmem[32][16];
1194   __shared__ float2 rhs_shmem[64][8];
1195 
1196   const Index m_block_idx = blockIdx.x;
1197   const Index n_block_idx = blockIdx.y;
1198 
1199   const Index base_m = 64 * m_block_idx;
1200   const Index base_n = 64 * n_block_idx;
1201 
1202   if (base_m + 63 < m_size) {
1203     if (base_n + 63 < n_size) {
1204       EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
1205     } else {
1206       EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
1207     }
1208   } else {
1209     if (base_n + 63 < n_size) {
1210       EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
1211     } else {
1212       EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
1213     }
1214   }
1215 }
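
// Editor's note: the 16x16 variant above uses a smaller 64x64 output tile per block.
// Its shared-memory footprint, from the declarations:
//   lhs_shmem: 32 * 16 float2 = 4 KiB,  rhs_shmem: 64 * 8 float2 = 4 KiB   (8 KiB total),
// and it is launched with 16x16 = 256-thread blocks by the
// LaunchKernels<float, float, ...> specialization further below.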
1216 
1217 
1218 template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
1219 struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> :
1220     public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> > {
1221 
1222   typedef GpuDevice Device;
1223 
1224   typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
1225   typedef TensorContractionEvaluatorBase<Self> Base;
1226 
1227   typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
1228   typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
1229   typedef typename XprType::Index Index;
1230   typedef typename XprType::CoeffReturnType CoeffReturnType;
1231   typedef typename PacketType<CoeffReturnType, GpuDevice>::type PacketReturnType;
1232 
1233   enum {
1234     Layout = TensorEvaluator<LeftArgType, Device>::Layout,
1235   };
1236 
1237   // Most of the code is assuming that both input tensors are ColMajor. If the
1238   // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
1239   // If we want to compute A * B = C, where A is LHS and B is RHS, the code
1240   // will pretend B is LHS and A is RHS.
1241   typedef typename internal::conditional<
1242     static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
1243   typedef typename internal::conditional<
1244     static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
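
  // Editor's note: concretely, for ColMajor these typedefs are the identity
  // (EvalLeftArgType == LeftArgType); for RowMajor they swap the operands.
  // This is just the transpose identity C = A * B  <=>  C^T = B^T * A^T: a
  // row-major contraction is evaluated as a column-major contraction with the
  // roles of the two inputs exchanged.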
1245 
1246   static const int LDims =
1247       internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
1248   static const int RDims =
1249       internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
1250   static const int ContractDims = internal::array_size<Indices>::value;
1251 
1252   typedef array<Index, LDims> left_dim_mapper_t;
1253   typedef array<Index, RDims> right_dim_mapper_t;
1254 
1255   typedef array<Index, ContractDims> contract_t;
1256   typedef array<Index, LDims - ContractDims> left_nocontract_t;
1257   typedef array<Index, RDims - ContractDims> right_nocontract_t;
1258 
1259   static const int NumDims = LDims + RDims - 2 * ContractDims;
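
  // Editor's note: e.g. a plain matrix product has LDims == 2, RDims == 2 and
  // ContractDims == 1, giving NumDims == 2 + 2 - 2 * 1 == 2; each contracted
  // index pair removes one dimension from each operand.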
1260 
1261   typedef DSizes<Index, NumDims> Dimensions;
1262 
1263   // typedefs needed in evalTo
1264   typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
1265   typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
1266 
1267   typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
1268   typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
1269 
1270   typedef typename LeftEvaluator::Dimensions LeftDimensions;
1271   typedef typename RightEvaluator::Dimensions RightDimensions;
1272 
1273   TensorEvaluator(const XprType& op, const Device& device) :
1274       Base(op, device)
1275   {
1276     EIGEN_STATIC_ASSERT( (internal::is_same<OutputKernelType, const NoOpOutputKernel>::value),
1277                           GPU_TENSOR_CONTRACTION_DOES_NOT_SUPPORT_OUTPUT_KERNELS);
1278   }
1279 
1280   // We need to redefine this method to make nvcc happy
1281   EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
1282     this->m_leftImpl.evalSubExprsIfNeeded(NULL);
1283     this->m_rightImpl.evalSubExprsIfNeeded(NULL);
1284     if (data) {
1285       evalTo(data);
1286       return false;
1287     } else {
1288       this->m_result = static_cast<Scalar *>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
1289       evalTo(this->m_result);
1290       return true;
1291     }
1292   }
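
  // Editor's note: a hedged usage sketch of the code path that reaches this
  // evaluator; the names d_a, d_b, d_c and gpu_device are illustrative, and the
  // three buffers are assumed to already reside in device memory:
  //
  //   Eigen::GpuStreamDevice stream;
  //   Eigen::GpuDevice gpu_device(&stream);
  //   Eigen::TensorMap<Eigen::Tensor<float, 2> > A(d_a, m, k), B(d_b, k, n), C(d_c, m, n);
  //   Eigen::array<Eigen::IndexPair<int>, 1> dims = {{ Eigen::IndexPair<int>(1, 0) }};
  //   C.device(gpu_device) = A.contract(B, dims);  // ends up in evalTo()/evalTyped() below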
1293 
1294   void evalTo(Scalar* buffer) const {
1295     if (this->m_lhs_inner_dim_contiguous) {
1296       if (this->m_rhs_inner_dim_contiguous) {
1297         if (this->m_rhs_inner_dim_reordered) {
1298           evalTyped<true, true, true, Unaligned>(buffer);
1299         }
1300         else {
1301           evalTyped<true, true, false, Unaligned>(buffer);
1302         }
1303       }
1304       else {
1305         if (this->m_rhs_inner_dim_reordered) {
1306           evalTyped<true, false, true, Unaligned>(buffer);
1307         }
1308         else {
1309           evalTyped<true, false, false, Unaligned>(buffer);
1310         }
1311       }
1312     }
1313     else {
1314       if (this->m_rhs_inner_dim_contiguous) {
1315         if (this->m_rhs_inner_dim_reordered) {
1316           evalTyped<false, true, true, Unaligned>(buffer);
1317         }
1318         else {
1319           evalTyped<false, true, false, Unaligned>(buffer);
1320         }
1321       }
1322       else {
1323         if (this->m_rhs_inner_dim_reordered) {
1324           evalTyped<false, false, true, Unaligned>(buffer);
1325         }
1326         else {
1327           evalTyped<false, false, false, Unaligned>(buffer);
1328         }
1329       }
1330     }
1331   }
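
  // Editor's note: the nested branches above enumerate all 2^3 == 8 combinations
  // of the three runtime flags so that each one maps to its own statically
  // parameterized evalTyped instantiation; the alignment argument is always
  // Unaligned on this path.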
1332 
1333   template <typename LhsScalar, typename RhsScalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels {
1334     static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
1335     const Index m_blocks = (m + 63) / 64;
1336     const Index n_blocks = (n + 63) / 64;
1337     const dim3 num_blocks(m_blocks, n_blocks, 1);
1338     const dim3 block_size(8, 8, 8);
1339     LAUNCH_GPU_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
1340     }
1341   };
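
  // Editor's note: worked example of the launch arithmetic above, assuming
  // m = 1000 and n = 500:
  //   m_blocks = (1000 + 63) / 64 = 16,   n_blocks = (500 + 63) / 64 = 8,
  // i.e. a 16 x 8 grid of 8*8*8 = 512-thread blocks, one 64x64 output tile per
  // block; partial edge tiles are handled by the kernel's edge checks.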
1342 
1343   template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels<float, float, Index, LhsMapper, RhsMapper, OutputMapper> {
1344     static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
1345       if (m < 768 || n < 768) {
1346         const Index m_blocks = (m + 63) / 64;
1347         const Index n_blocks = (n + 63) / 64;
1348         const dim3 num_blocks(m_blocks, n_blocks, 1);
1349         const dim3 block_size(16, 16, 1);
1350         LAUNCH_GPU_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
1351       } else {
1352         const Index m_blocks = (m + 127) / 128;
1353         const Index n_blocks = (n + 63) / 64;
1354         const dim3 num_blocks(m_blocks, n_blocks, 1);
1355         const dim3 block_size(8, 32, 1);
1356         LAUNCH_GPU_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
1357       }
1358     }
1359   };
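
  // Editor's note: a plausible reading of the heuristic above: for smaller
  // problems (m < 768 or n < 768) the 64x64-tile kernel with 16x16 = 256-thread
  // blocks yields more blocks and hence more parallelism, while larger problems
  // use the 128x64-tile kernel with 8x32 = 256-thread blocks to amortize more
  // work per block. For example, m = n = 1024 takes the second branch and
  // launches an 8 x 16 grid ((1024 + 127) / 128 = 8, (1024 + 63) / 64 = 16).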
1360 
1361   template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
1362   void evalTyped(Scalar* buffer) const {
1363     // columns in left side, rows in right side
1364     const Index k = this->m_k_size;
1365     EIGEN_UNUSED_VARIABLE(k)
1366 
1367     // rows in left side
1368     const Index m = this->m_i_size;
1369 
1370     // columns in right side
1371     const Index n = this->m_j_size;
1372 
1373     // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
1374     this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
1375 
1376     typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
1377                                                    LeftEvaluator, left_nocontract_t,
1378                                                    contract_t, 4,
1379                                                    lhs_inner_dim_contiguous,
1380                                                    false, Unaligned> LhsMapper;
1381 
1382     typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
1383                                                    RightEvaluator, right_nocontract_t,
1384                                                    contract_t, 4,
1385                                                    rhs_inner_dim_contiguous,
1386                                                    rhs_inner_dim_reordered, Unaligned> RhsMapper;
1387 
1388     typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
1389 
1390 
1391     // initialize data mappers
1392     LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
1393                   this->m_left_contracting_strides, this->m_k_strides);
1394 
1395     RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
1396                   this->m_right_contracting_strides, this->m_k_strides);
1397 
1398     OutputMapper output(buffer, m);
1399 
1400 #if defined(EIGEN_USE_HIP)
1401     setGpuSharedMemConfig(hipSharedMemBankSizeEightByte);
1402 #else
1403     setGpuSharedMemConfig(cudaSharedMemBankSizeEightByte);
1404 #endif
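
    // Editor's note: the kernels index shared memory as float2 (8 bytes per
    // access), so the eight-byte bank mode keeps consecutive threads in distinct
    // banks. setGpuSharedMemConfig appears to be Eigen's portability helper; as
    // a sketch, the equivalent raw CUDA runtime call would be:
    //
    //   cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte);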
1405 
1406     LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output,  m, n, k, this->m_device);
1407   }
1408 };
1409 
1410 } // end namespace Eigen
1411 
1412 #endif // EIGEN_USE_GPU and EIGEN_GPUCC
1413 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
1414