1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014-2015 Benoit Steiner <[email protected]>
5 // Copyright (C) 2015 Navdeep Jaitly <[email protected]>
6 // Copyright (C) 2014 Eric Martin <[email protected]>
7 //
8 // This Source Code Form is subject to the terms of the Mozilla
9 // Public License v. 2.0. If a copy of the MPL was not distributed
10 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
11
12 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
13 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
14
15 #if defined(EIGEN_USE_GPU) && defined(EIGEN_GPUCC)
16
17 namespace Eigen {
18
// Generic (any Scalar) contraction kernel body.
//
// Computes one 64x64 tile of output = lhs * rhs; the tile origin is
// (64 * blockIdx.x, 64 * blockIdx.y).  The indexing math below assumes an
// 8x8x8 thread block (512 threads) — threadIdx.{x,y,z} are each used as
// values in [0, 8).  Each thread accumulates an 8x8 sub-block of partial
// results in registers, which are then reduced across threadIdx.x via warp
// shuffles and written out through shared memory.
//
// lhs_shmem / rhs_shmem are 72*64-scalar staging buffers provided by the
// caller.  When needs_edge_check is false the tile is known to lie fully
// inside the output, so most bounds tests compile away.
template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool needs_edge_check>
__device__ EIGEN_STRONG_INLINE void
EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                               const OutputMapper output, Scalar* lhs_shmem, Scalar* rhs_shmem,
                               const Index m_size, const Index n_size, const Index k_size) {

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  // top-left corner of this block's 64x64 output tile
  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  // declare and initialize 64 registers for output 8x8 block

  // prefetch registers
  Scalar lhs_pf0;
  Scalar lhs_pf1;
  Scalar lhs_pf2;
  Scalar lhs_pf3;
  Scalar lhs_pf4;
  Scalar lhs_pf5;
  Scalar lhs_pf6;
  Scalar lhs_pf7;

  Scalar rhs_pf0;
  Scalar rhs_pf1;
  Scalar rhs_pf2;
  Scalar rhs_pf3;
  Scalar rhs_pf4;
  Scalar rhs_pf5;
  Scalar rhs_pf6;
  Scalar rhs_pf7;

  // shared memory is formatted
  // (contract idx in block, nocontract idx in block, block idx)
  // where block idx is column major. This transposition limits the number of
  // bank conflicts when reading the LHS. The core idea is that since the contracting
  // index is shared by both sides, then the contracting index should be in threadIdx.x.

  // On the LHS, we pad each row inside of each block with an extra element. This makes
  // each block 8 rows of 9 elements, which is 72 elements. This gives no bank conflicts
  // on writes and very few 2-way conflicts on reads. There is an 8x8 grid of these blocks.

  // On the RHS we just add 8 padding elements to the end of each block. This gives no bank
  // conflicts on writes and also none on reads.

  // storage indices
  const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
  const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;

  // each thread stores 8 elements per side, one into each of the 8 "k" blocks
  // (576 = 8 blocks * 72 elements per block)
  const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
  const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
  const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
  const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
  const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
  const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
  const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
  const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;

  const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
  const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
  const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
  const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
  const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
  const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
  const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
  const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;

  // in the loading code, the following variables are important:
  // threadIdx.x: the vertical position in an 8x8 block
  // threadIdx.y: the vertical index of the 8x8 block in the grid
  // threadIdx.z: the horizontal position in an 8x8 block
  // k: the horizontal index of the 8x8 block in the grid
  //
  // The k parameter is implicit (it was the loop counter for a loop that went
  // from 0 to <8, but now that loop is unrolled in the below code.)

  const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
  const Index lhs_vert = base_m + load_idx_vert;

// Loads one 64x64 panel of each side into the prefetch registers, zero-filling
// anything that falls outside the matrices.  The descending else-if ladders
// load only the columns that are in range near the edges.
#define prefetchIntoRegisters(base_k) \
  { \
    lhs_pf0 = conv(0); \
    lhs_pf1 = conv(0); \
    lhs_pf2 = conv(0); \
    lhs_pf3 = conv(0); \
    lhs_pf4 = conv(0); \
    lhs_pf5 = conv(0); \
    lhs_pf6 = conv(0); \
    lhs_pf7 = conv(0); \
    \
    rhs_pf0 = conv(0); \
    rhs_pf1 = conv(0); \
    rhs_pf2 = conv(0); \
    rhs_pf3 = conv(0); \
    rhs_pf4 = conv(0); \
    rhs_pf5 = conv(0); \
    rhs_pf6 = conv(0); \
    rhs_pf7 = conv(0); \
    \
    if (!needs_edge_check || lhs_vert < m_size) { \
      const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8; \
      const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8; \
      const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8; \
      const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8; \
      const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8; \
      const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8; \
      const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8; \
      const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8; \
      \
      if (!needs_edge_check || lhs_horiz_7 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
        lhs_pf7 = lhs(lhs_vert, lhs_horiz_7); \
      } else if (lhs_horiz_6 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
      } else if (lhs_horiz_5 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
      } else if (lhs_horiz_4 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
      } else if (lhs_horiz_3 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
      } else if (lhs_horiz_2 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
      } else if (lhs_horiz_1 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
      } else if (lhs_horiz_0 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
      } \
    } \
    \
    const Index rhs_vert = base_k + load_idx_vert; \
    if (!needs_edge_check || rhs_vert < k_size) { \
      const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8; \
      const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8; \
      const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8; \
      const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8; \
      const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8; \
      const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8; \
      const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8; \
      const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8; \
      \
      if (rhs_horiz_7 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
        rhs_pf7 = rhs(rhs_vert, rhs_horiz_7); \
      } else if (rhs_horiz_6 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
      } else if (rhs_horiz_5 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
      } else if (rhs_horiz_4 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
      } else if (rhs_horiz_3 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
      } else if (rhs_horiz_2 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
      } else if (rhs_horiz_1 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
      } else if (rhs_horiz_0 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
      } \
    } \
  } \

// Copies the 16 prefetch registers into the padded shared-memory layout
// described above.
#define writeRegToShmem(_) \
  lhs_shmem[lhs_store_idx_0] = lhs_pf0; \
  rhs_shmem[rhs_store_idx_0] = rhs_pf0; \
  \
  lhs_shmem[lhs_store_idx_1] = lhs_pf1; \
  rhs_shmem[rhs_store_idx_1] = rhs_pf1; \
  \
  lhs_shmem[lhs_store_idx_2] = lhs_pf2; \
  rhs_shmem[rhs_store_idx_2] = rhs_pf2; \
  \
  lhs_shmem[lhs_store_idx_3] = lhs_pf3; \
  rhs_shmem[rhs_store_idx_3] = rhs_pf3; \
  \
  lhs_shmem[lhs_store_idx_4] = lhs_pf4; \
  rhs_shmem[rhs_store_idx_4] = rhs_pf4; \
  \
  lhs_shmem[lhs_store_idx_5] = lhs_pf5; \
  rhs_shmem[rhs_store_idx_5] = rhs_pf5; \
  \
  lhs_shmem[lhs_store_idx_6] = lhs_pf6; \
  rhs_shmem[rhs_store_idx_6] = rhs_pf6; \
  \
  lhs_shmem[lhs_store_idx_7] = lhs_pf7; \
  rhs_shmem[rhs_store_idx_7] = rhs_pf7; \

  // declare and initialize result array
#define res(i, j) _res_##i##j
#define initResultRow(i) \
  Scalar res(i, 0) = conv(0); \
  Scalar res(i, 1) = conv(0); \
  Scalar res(i, 2) = conv(0); \
  Scalar res(i, 3) = conv(0); \
  Scalar res(i, 4) = conv(0); \
  Scalar res(i, 5) = conv(0); \
  Scalar res(i, 6) = conv(0); \
  Scalar res(i, 7) = conv(0); \

  // conv(0) produces a correctly-typed zero for any Scalar (incl. complex)
  internal::scalar_cast_op<int, Scalar> conv;
  initResultRow(0);
  initResultRow(1);
  initResultRow(2);
  initResultRow(3);
  initResultRow(4);
  initResultRow(5);
  initResultRow(6);
  initResultRow(7);
#undef initResultRow

  for (Index base_k = 0; base_k < k_size; base_k += 64) {
    // wait for previous iteration to finish with shmem. Despite common sense,
    // the code is a bit faster with this here than at bottom of loop
    __syncthreads();

    prefetchIntoRegisters(base_k);
    writeRegToShmem();

    #undef prefetchIntoRegisters
    #undef writeRegToShmem

    // wait for shared mem packing to be done before starting computation
    __syncthreads();

    // compute 8x8 matrix product by outer product. This involves packing one column
    // of LHS and one row of RHS into registers (takes 16 registers).

#define lcol(i) _lcol##i
    Scalar lcol(0);
    Scalar lcol(1);
    Scalar lcol(2);
    Scalar lcol(3);
    Scalar lcol(4);
    Scalar lcol(5);
    Scalar lcol(6);
    Scalar lcol(7);

#define rrow(j) _rrow##j
    Scalar rrow(0);
    Scalar rrow(1);
    Scalar rrow(2);
    Scalar rrow(3);
    Scalar rrow(4);
    Scalar rrow(5);
    Scalar rrow(6);
    Scalar rrow(7);

    // Now x corresponds to k, y to m, and z to n
    const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
    const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];

#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]

// Interleaves LHS-column and RHS-row loads to give the compiler ILP.
#define loadData(i, j) \
    lcol(0) = lhs_element(0, j); \
    rrow(0) = rhs_element(i, 0); \
    lcol(1) = lhs_element(1, j); \
    rrow(1) = rhs_element(i, 1); \
    lcol(2) = lhs_element(2, j); \
    rrow(2) = rhs_element(i, 2); \
    lcol(3) = lhs_element(3, j); \
    rrow(3) = rhs_element(i, 3); \
    lcol(4) = lhs_element(4, j); \
    rrow(4) = rhs_element(i, 4); \
    lcol(5) = lhs_element(5, j); \
    rrow(5) = rhs_element(i, 5); \
    lcol(6) = lhs_element(6, j); \
    rrow(6) = rhs_element(i, 6); \
    lcol(7) = lhs_element(7, j); \
    rrow(7) = rhs_element(i, 7); \

// One rank-1 update column: res(:, j) += lcol * rrow(j)
#define computeCol(j) \
    res(0, j) += lcol(0) * rrow(j); \
    res(1, j) += lcol(1) * rrow(j); \
    res(2, j) += lcol(2) * rrow(j); \
    res(3, j) += lcol(3) * rrow(j); \
    res(4, j) += lcol(4) * rrow(j); \
    res(5, j) += lcol(5) * rrow(j); \
    res(6, j) += lcol(6) * rrow(j); \
    res(7, j) += lcol(7) * rrow(j); \

#define computePass(i) \
    loadData(i, i); \
    \
    computeCol(0); \
    computeCol(1); \
    computeCol(2); \
    computeCol(3); \
    computeCol(4); \
    computeCol(5); \
    computeCol(6); \
    computeCol(7); \

    computePass(0);
    computePass(1);
    computePass(2);
    computePass(3);
    computePass(4);
    computePass(5);
    computePass(6);
    computePass(7);

#undef lcol
#undef rrow
#undef lhs_element
#undef rhs_element
#undef loadData
#undef computeCol
#undef computePass
  } // end loop over k

  // we've now iterated over all of the large (ie width 64) k blocks and
  // accumulated results in registers. At this point thread (x, y, z) contains
  // the sum across all big k blocks of the product of little k block of index (x, y)
  // with block of index (y, z). To compute the final output, we need to reduce
  // the 8 threads over y by summation.
#if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)
#else
  // CUDA 9+ requires the *_sync variants with an explicit participant mask
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor_sync(0xFFFFFFFF, res(i, j), mask)
#endif

#define reduceRow(i, mask) \
  shuffleInc(i, 0, mask); \
  shuffleInc(i, 1, mask); \
  shuffleInc(i, 2, mask); \
  shuffleInc(i, 3, mask); \
  shuffleInc(i, 4, mask); \
  shuffleInc(i, 5, mask); \
  shuffleInc(i, 6, mask); \
  shuffleInc(i, 7, mask); \

#define reduceMatrix(mask) \
  reduceRow(0, mask); \
  reduceRow(1, mask); \
  reduceRow(2, mask); \
  reduceRow(3, mask); \
  reduceRow(4, mask); \
  reduceRow(5, mask); \
  reduceRow(6, mask); \
  reduceRow(7, mask); \

  // actually perform the reduction, now each thread of index (_, y, z)
  // contains the correct values in its registers that belong in the output
  // block (xor with masks 1, 2, 4 sums over the 8 lanes of threadIdx.x)
  reduceMatrix(1);
  reduceMatrix(2);
  reduceMatrix(4);

#undef shuffleInc
#undef reduceRow
#undef reduceMatrix

  // now we need to copy the 64 values into main memory. We can't split work
  // among threads because all variables are in registers. There's 2 ways
  // to do this:
  // (1) have 1 thread do 64 writes from registers into global memory
  // (2) have 1 thread do 64 writes into shared memory, and then 8 threads
  // each do 8 writes into global memory. We can just overwrite the shared
  // memory from the problem we just solved.
  // (2) is slightly faster than (1) due to less branching and more ILP

  // TODO: won't yield much gain, but could just use currently unused shared mem
  // and then we won't have to sync
  // wait for shared mem to be out of use
  __syncthreads();

#define writeResultShmem(i, j) \
  lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j); \

#define writeRow(i) \
  writeResultShmem(i, 0); \
  writeResultShmem(i, 1); \
  writeResultShmem(i, 2); \
  writeResultShmem(i, 3); \
  writeResultShmem(i, 4); \
  writeResultShmem(i, 5); \
  writeResultShmem(i, 6); \
  writeResultShmem(i, 7); \

  // after the reduction only lane x == 0 holds the final sums
  if (threadIdx.x == 0) {
    writeRow(0);
    writeRow(1);
    writeRow(2);
    writeRow(3);
    writeRow(4);
    writeRow(5);
    writeRow(6);
    writeRow(7);
  }
#undef writeResultShmem
#undef writeRow

  // number of in-bounds rows/columns this thread may write (8 when interior)
  const int max_i_write = numext::mini((int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
  const int max_j_write = numext::mini((int)((n_size - base_n - threadIdx.z + 7) / 8), 8);

  if (threadIdx.x < max_i_write) {
    if (max_j_write == 8) {
      // TODO: can i trade bank conflicts for coalesced writes?
      Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
      Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
      Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
      Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
      Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
      Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
      Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
      Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];

      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
    } else {
#pragma unroll 7
      for (int j = 0; j < max_j_write; j++) {
        Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
        output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
      }
    }
  }
#undef res
}
501
502
// Entry point for the generic contraction kernel: output = lhs * rhs.
// Launched on a 2D grid where block (x, y) owns the 64x64 output tile whose
// top-left corner is (64 * x, 64 * y).  Selects at compile time between the
// fast interior path and the bounds-checked edge path of
// EigenContractionKernelInternal.
template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(512, 1)
#else
__launch_bounds__(512)
#endif
EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  // Shared-memory staging buffers handed down to the internal kernel
  // (72 * 64 scalars each: an 8x8 grid of padded 72-element blocks).
  __shared__ Scalar lhs_shmem[72 * 64];
  __shared__ Scalar rhs_shmem[72 * 64];

  // Top-left corner of the 64x64 output tile this block is responsible for.
  const Index tile_row = 64 * Index(blockIdx.x);
  const Index tile_col = 64 * Index(blockIdx.y);

  // Interior tiles fit entirely inside the output, so the edge checks can be
  // compiled out; the flag is a template parameter, hence the two calls.
  const bool tile_is_interior =
      (tile_row + 63 < m_size) && (tile_col + 63 < n_size);

  if (tile_is_interior) {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(
        lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  } else {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(
        lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  }
}
529
530
// Float-specialized contraction kernel body using float4 packet loads.
//
// Computes one 64x64 output tile with origin (base_m, base_n), consuming the
// contraction dimension 16 indices per iteration.  The indexing math assumes
// a 16x16 thread block (threadIdx.x and threadIdx.y are both used as values
// in [0, 16) — e.g. lhs_shmem2 has 16 columns indexed by threadIdx.x); the
// launcher is not visible here — TODO confirm blockDim against the caller.
// Each thread owns a 4x4 patch of results held in four float4 registers.
// CHECK_LHS_BOUNDARY / CHECK_RHS_BOUNDARY enable m- and n-edge handling at
// compile time.
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ __forceinline__ void
EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output, float2 lhs_shmem2[][16],
                       float2 rhs_shmem2[][8], const Index m_size,
                       const Index n_size, const Index k_size,
                       const Index base_m, const Index base_n) {

  // prefetch registers
  float4 lhs_pf0, rhs_pf0;

  // accumulators: results[i] holds the 4 output rows for column (base) + i
  float4 results[4];
  for (int i=0; i < 4; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }

// Loads 4 vertically adjacent LHS floats into `reg`.  Uses a single float4
// packet load when all 4 rows are in range, otherwise falls back to scalar
// loads for the rows that remain; out-of-range lanes keep their zero fill.
#define prefetch_lhs(reg, row, col) \
  if (!CHECK_LHS_BOUNDARY) { \
    if (col < k_size) { \
      reg =lhs.template loadPacket<float4,Unaligned>(row, col); \
    } \
  } else { \
    if (col < k_size) { \
      if (row + 3 < m_size) { \
        reg =lhs.template loadPacket<float4,Unaligned>(row, col); \
      } else if (row + 2 < m_size) { \
        reg.x =lhs(row + 0, col); \
        reg.y =lhs(row + 1, col); \
        reg.z =lhs(row + 2, col); \
      } else if (row + 1 < m_size) { \
        reg.x =lhs(row + 0, col); \
        reg.y =lhs(row + 1, col); \
      } else if (row < m_size) { \
        reg.x =lhs(row + 0, col); \
      } \
    } \
  } \

  Index lhs_vert = base_m+threadIdx.x*4;

  for (Index k = 0; k < k_size; k += 16) {

    // zero-fill so edge lanes contribute nothing to the accumulators
    lhs_pf0 = internal::pset1<float4>(0);
    rhs_pf0 = internal::pset1<float4>(0);

    Index lhs_horiz = threadIdx.y+k;
    prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)

    Index rhs_vert = k+(threadIdx.x%4)*4;
    Index rhs_horiz0 = (threadIdx.x>>2)+threadIdx.y*4+base_n;

    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        // just CHECK_RHS_BOUNDARY
        rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
      } else if (rhs_vert + 2 < k_size) {
        // just CHECK_RHS_BOUNDARY
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
      }
    } else {
      if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
    // Swap two of the four RHS components with the partner lane 4 apart so
    // the subsequent shared-memory stores land in the layout described below.
    float x1, x2 ;
    // the following can be a bitwise operation..... some day.
    if((threadIdx.x%8) < 4) {
      x1 = rhs_pf0.y;
      x2 = rhs_pf0.w;
    } else {
      x1 = rhs_pf0.x;
      x2 = rhs_pf0.z;
    }
#if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
    x1 = __shfl_xor(x1, 4);
    x2 = __shfl_xor(x2, 4);
#else
    // CUDA 9+ requires the *_sync variants with an explicit participant mask
    x1 = __shfl_xor_sync(0xFFFFFFFF, x1, 4);
    x2 = __shfl_xor_sync(0xFFFFFFFF, x2, 4);
#endif
    if((threadIdx.x%8) < 4) {
      rhs_pf0.y = x1;
      rhs_pf0.w = x2;
    } else {
      rhs_pf0.x = x1;
      rhs_pf0.z = x2;
    }

    // We have 64 features.
    // Row 0 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 0, 1.
    // Row 1 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 2, 3.
    // ...
    // Row 31 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 62, 63
    // Row 32 -> times (2, 6, 10, 14, 3, 7, 11, 15) for features 0, 1
    // ...
    rhs_shmem2[(threadIdx.x>>3)+ threadIdx.y*2][threadIdx.x%8] = make_float2(rhs_pf0.x, rhs_pf0.y);
    rhs_shmem2[(threadIdx.x>>3)+ threadIdx.y*2+32][threadIdx.x%8] = make_float2(rhs_pf0.z, rhs_pf0.w);

    // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
    // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
    // ...
    // Row 15 (time 15) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
    // Row 16 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63)
    // ...

    lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y+16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w);


// Accumulates one rank-1 update of the thread's 4x4 result patch:
// rows come from (fl1, fl2), columns from (fr1, fr2).
#define add_vals(fl1, fl2, fr1, fr2)\
    results[0].x += fl1.x * fr1.x;\
    results[0].y += fl1.y * fr1.x;\
    results[0].z += fl2.x * fr1.x;\
    results[0].w += fl2.y * fr1.x;\
    \
    results[1].x += fl1.x * fr1.y;\
    results[1].y += fl1.y * fr1.y;\
    results[1].z += fl2.x * fr1.y;\
    results[1].w += fl2.y * fr1.y;\
    \
    results[2].x += fl1.x * fr2.x;\
    results[2].y += fl1.y * fr2.x;\
    results[2].z += fl2.x * fr2.x;\
    results[2].w += fl2.y * fr2.x;\
    \
    results[3].x += fl1.x * fr2.y;\
    results[3].y += fl1.y * fr2.y;\
    results[3].z += fl2.x * fr2.y;\
    results[3].w += fl2.y * fr2.y;\

    // make the shared-memory panels visible to all threads before computing
    __syncthreads();

    // Do the multiplies.
    #pragma unroll
    for (int koff = 0; koff < 16; koff ++) {
      // 32 x threads.
      float2 fl1 = lhs_shmem2[koff][threadIdx.x];
      float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];

      int start_feature = threadIdx.y * 4;
      float2 fr1 = rhs_shmem2[(start_feature>>1) + 32*((koff%4)/2)][koff/4 + (koff%2)*4];
      float2 fr2 = rhs_shmem2[(start_feature>>1) + 1 + 32*((koff%4)/2)][koff/4 + (koff%2)*4];

      add_vals(fl1, fl2, fr1, fr2)
    }
    // all reads done before the next iteration overwrites the panels
    __syncthreads();
  }

#undef prefetch_lhs
#undef add_vals

  Index horiz_base = threadIdx.y*4+base_n;
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    // interior tile: unconditional 4x4 store
    for (int i = 0; i < 4; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    // CHECK LHS
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    // CHECK RHS
    /*
    int ncols_rem = fminf(n_size- horiz_base, 4);
    for (int i = 0; i < ncols_rem; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }*/
    for (int i = 0; i < 4; i++) {
      if (horiz_base+i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    // CHECK both boundaries.
    for (int i = 0; i < 4; i++) {
      if (horiz_base+i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}
769
770
771 template<typename Index, typename LhsMapper,
772 typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
773 bool CHECK_RHS_BOUNDARY>
774 __device__ __forceinline__ void
EigenFloatContractionKernelInternal(const LhsMapper lhs,const RhsMapper rhs,const OutputMapper output,float2 lhs_shmem2[][32],float2 rhs_shmem2[][8],const Index m_size,const Index n_size,const Index k_size,const Index base_m,const Index base_n)775 EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
776 const OutputMapper output, float2 lhs_shmem2[][32],
777 float2 rhs_shmem2[][8], const Index m_size,
778 const Index n_size, const Index k_size,
779 const Index base_m, const Index base_n) {
780
781 // prefetch registers
782 float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3;
783 float4 rhs_pf0, rhs_pf1;
784
785 float4 results[8];
786 for (int i=0; i < 8; i++) {
787 results[i].x = results[i].y = results[i].z = results[i].w = 0;
788 }
789
790 Index lhs_vert = base_m+threadIdx.x*4+(threadIdx.y%4)*32;
791 for (Index k = 0; k < k_size; k += 32) {
792 lhs_pf0 = internal::pset1<float4>(0);
793 lhs_pf1 = internal::pset1<float4>(0);
794 lhs_pf2 = internal::pset1<float4>(0);
795 lhs_pf3 = internal::pset1<float4>(0);
796
797 rhs_pf0 = internal::pset1<float4>(0);
798 rhs_pf1 = internal::pset1<float4>(0);
799
800 if (!CHECK_LHS_BOUNDARY) {
801 if ((threadIdx.y/4+k+24) < k_size) {
802 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
803 lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
804 lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
805 lhs_pf3 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
806 } else if ((threadIdx.y/4+k+16) < k_size) {
807 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
808 lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
809 lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
810 } else if ((threadIdx.y/4+k+8) < k_size) {
811 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
812 lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
813 } else if ((threadIdx.y/4+k) < k_size) {
814 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
815 }
816 } else {
817 // just CHECK_LHS_BOUNDARY
818 if (lhs_vert + 3 < m_size) {
819 if ((threadIdx.y/4+k+24) < k_size) {
820 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
821 lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
822 lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
823 lhs_pf3 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
824 } else if ((threadIdx.y/4+k+16) < k_size) {
825 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
826 lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
827 lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
828 } else if ((threadIdx.y/4+k+8) < k_size) {
829 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
830 lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
831 } else if ((threadIdx.y/4+k) < k_size) {
832 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
833 }
834 } else if (lhs_vert + 2 < m_size) {
835 if ((threadIdx.y/4+k+24) < k_size) {
836 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
837 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
838 lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k));
839 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
840 lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
841 lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
842 lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
843 lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
844 lhs_pf2.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
845 lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
846 lhs_pf3.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
847 lhs_pf3.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+24));
848 } else if ((threadIdx.y/4+k+16) < k_size) {
849 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
850 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
851 lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k));
852 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
853 lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
854 lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
855 lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
856 lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
857 lhs_pf2.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
858 } else if ((threadIdx.y/4+k+8) < k_size) {
859 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
860 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
861 lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k));
862 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
863 lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
864 lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
865 } else if ((threadIdx.y/4+k) < k_size) {
866 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
867 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
868 lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k));
869 }
870 } else if (lhs_vert + 1 < m_size) {
871 if ((threadIdx.y/4+k+24) < k_size) {
872 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
873 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
874 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
875 lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
876 lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
877 lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
878 lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
879 lhs_pf3.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
880 } else if ((threadIdx.y/4+k+16) < k_size) {
881 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
882 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
883 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
884 lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
885 lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
886 lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
887 } else if ((threadIdx.y/4+k+8) < k_size) {
888 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
889 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
890 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
891 lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
892 } else if ((threadIdx.y/4+k) < k_size) {
893 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
894 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
895 }
896 } else if (lhs_vert < m_size) {
897 if ((threadIdx.y/4+k+24) < k_size) {
898 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
899 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
900 lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
901 lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
902 } else if ((threadIdx.y/4+k+16) < k_size) {
903 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
904 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
905 lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
906 } else if ((threadIdx.y/4+k+8) < k_size) {
907 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
908 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
909 } else if ((threadIdx.y/4+k) < k_size) {
910 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
911 }
912 }
913 }
914 __syncthreads();
915 Index rhs_vert = k+threadIdx.x*4;
916 Index rhs_horiz0 = threadIdx.y*2+base_n;
917 Index rhs_horiz1 = threadIdx.y*2+1+base_n;
918 if (!CHECK_RHS_BOUNDARY) {
919 if ((rhs_vert + 3) < k_size) {
920 // just CHECK_RHS_BOUNDARY
921 rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
922 rhs_pf1 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz1);
923 } else if (rhs_vert + 2 < k_size) {
924 // just CHECK_RHS_BOUNDARY
925 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
926 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
927 rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
928 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
929 rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
930 rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
931 } else if (rhs_vert + 1 < k_size) {
932 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
933 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
934 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
935 rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
936 } else if (rhs_vert < k_size) {
937 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
938 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
939 }
940 } else {
941 if (rhs_horiz1 < n_size) {
942 if ((rhs_vert + 3) < k_size) {
943 // just CHECK_RHS_BOUNDARY
944 rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
945 rhs_pf1 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz1);
946 } else if (rhs_vert + 2 < k_size) {
947 // just CHECK_RHS_BOUNDARY
948 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
949 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
950 rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
951 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
952 rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
953 rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
954 } else if (k+threadIdx.x*4 + 1 < k_size) {
955 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
956 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
957 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
958 rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
959 } else if (k+threadIdx.x*4 < k_size) {
960 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
961 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
962 }
963 } else if (rhs_horiz0 < n_size) {
964 if ((rhs_vert + 3) < k_size) {
965 // just CHECK_RHS_BOUNDARY
966 rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
967 } else if ((rhs_vert + 2) < k_size) {
968 // just CHECK_RHS_BOUNDARY
969 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
970 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
971 rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
972 } else if ((rhs_vert + 1) < k_size) {
973 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
974 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
975 } else if (rhs_vert < k_size) {
976 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
977 }
978 }
979 }
980 __syncthreads();
981 // Loaded. Do computation
982 // Row 0 -> times (0, 4, 8, .. 28) for features 0, 1.
983 // Row 1 -> times (0, 4, 8, .. 28) for features 2, 3.
984 // ..
985 // Row 31 -> times (0, 4, 8, .. 28) for features 62, 63
986 rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x);
987 // Row 32 -> times (1, 5, 9, .. 29) for features 0, 1.
988 // Row 33 -> times (1, 5, 9, .. 29) for features 2, 3.
989 // ..
990 rhs_shmem2[threadIdx.y+32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y);
991 // Row 64 -> times (2, 6, 10, .. 30) for features 0, 1.
992 // Row 65 -> times (2, 6, 10, .. 30) for features 2, 3.
993 rhs_shmem2[threadIdx.y+64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z);
994 // Row 96 -> times (3, 7, 11, .. 31) for features 0, 1.
995 // Row 97 -> times (3, 7, 11, .. 31) for features 2, 3.
996 rhs_shmem2[threadIdx.y+96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w);
997
998 // LHS.
999 // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) .. (124, 125)
1000 // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) .. (124, 125)
1001 // ...
1002 // Row 8 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63) .. (126, 127)
1003 // Row 15 (time 7) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63) .. (126, 127)
1004
1005
1006 #define add_vals(a_feat1, a_feat2, f1, f2, f3, f4)\
1007 results[0].x += a_feat1.x * f1.x;\
1008 results[1].x += a_feat1.x * f1.y;\
1009 results[2].x += a_feat1.x * f2.x;\
1010 results[3].x += a_feat1.x * f2.y;\
1011 results[4].x += a_feat1.x * f3.x;\
1012 results[5].x += a_feat1.x * f3.y;\
1013 results[6].x += a_feat1.x * f4.x;\
1014 results[7].x += a_feat1.x * f4.y;\
1015 \
1016 results[0].y += a_feat1.y * f1.x;\
1017 results[1].y += a_feat1.y * f1.y;\
1018 results[2].y += a_feat1.y * f2.x;\
1019 results[3].y += a_feat1.y * f2.y;\
1020 results[4].y += a_feat1.y * f3.x;\
1021 results[5].y += a_feat1.y * f3.y;\
1022 results[6].y += a_feat1.y * f4.x;\
1023 results[7].y += a_feat1.y * f4.y;\
1024 \
1025 results[0].z += a_feat2.x * f1.x;\
1026 results[1].z += a_feat2.x * f1.y;\
1027 results[2].z += a_feat2.x * f2.x;\
1028 results[3].z += a_feat2.x * f2.y;\
1029 results[4].z += a_feat2.x * f3.x;\
1030 results[5].z += a_feat2.x * f3.y;\
1031 results[6].z += a_feat2.x * f4.x;\
1032 results[7].z += a_feat2.x * f4.y;\
1033 \
1034 results[0].w += a_feat2.y * f1.x;\
1035 results[1].w += a_feat2.y * f1.y;\
1036 results[2].w += a_feat2.y * f2.x;\
1037 results[3].w += a_feat2.y * f2.y;\
1038 results[4].w += a_feat2.y * f3.x;\
1039 results[5].w += a_feat2.y * f3.y;\
1040 results[6].w += a_feat2.y * f4.x;\
1041 results[7].w += a_feat2.y * f4.y;\
1042
1043 lhs_shmem2[threadIdx.y/4][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.x, lhs_pf0.y);
1044 lhs_shmem2[threadIdx.y/4+8][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.x, lhs_pf1.y);
1045 lhs_shmem2[threadIdx.y/4+16][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.x, lhs_pf2.y);
1046 lhs_shmem2[threadIdx.y/4+24][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.x, lhs_pf3.y);
1047
1048 lhs_shmem2[threadIdx.y/4 + 32][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.z, lhs_pf0.w);
1049 lhs_shmem2[threadIdx.y/4 + 40][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.z, lhs_pf1.w);
1050 lhs_shmem2[threadIdx.y/4 + 48][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.z, lhs_pf2.w);
1051 lhs_shmem2[threadIdx.y/4 + 56][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.z, lhs_pf3.w);
1052
1053 __syncthreads();
1054
1055 // Do the multiplies.
1056 #pragma unroll
1057 for (int koff = 0; koff < 32; koff ++) {
1058 float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8];
1059 float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8];
1060
1061 // first feature is at (threadIdx.y/4) * 8 last is at start + 8.
1062 int start_feature = (threadIdx.y / 4) * 8;
1063
1064 float2 br1 = rhs_shmem2[start_feature/2 + (koff % 4) * 32][koff/4];
1065 float2 br2 = rhs_shmem2[start_feature/2 + 1 + (koff % 4) * 32][koff/4];
1066 float2 br3 = rhs_shmem2[start_feature/2 + 2 + (koff % 4) * 32][koff/4];
1067 float2 br4 = rhs_shmem2[start_feature/2 + 3 + (koff % 4) * 32][koff/4];
1068
1069 add_vals(a3, a4, br1, br2, br3, br4)
1070 }
1071 __syncthreads();
1072 } // end loop over k
1073
1074 __syncthreads();
1075 Index horiz_base = (threadIdx.y/4)*8+base_n;
1076 if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
1077 for (int i = 0; i < 8; i++) {
1078 output(lhs_vert, horiz_base + i) = results[i].x;
1079 output(lhs_vert + 1, horiz_base + i) = results[i].y;
1080 output(lhs_vert + 2, horiz_base + i) = results[i].z;
1081 output(lhs_vert + 3, horiz_base + i) = results[i].w;
1082 }
1083 } else if (!CHECK_RHS_BOUNDARY) {
1084 if (lhs_vert + 3 < m_size) {
1085 for (int i = 0; i < 8; i++) {
1086 output(lhs_vert, horiz_base + i) = results[i].x;
1087 output(lhs_vert + 1, horiz_base + i) = results[i].y;
1088 output(lhs_vert + 2, horiz_base + i) = results[i].z;
1089 output(lhs_vert + 3, horiz_base + i) = results[i].w;
1090 }
1091 } else if (lhs_vert + 2 < m_size) {
1092 for (int i = 0; i < 8; i++) {
1093 output(lhs_vert, horiz_base + i) = results[i].x;
1094 output(lhs_vert + 1, horiz_base + i) = results[i].y;
1095 output(lhs_vert + 2, horiz_base + i) = results[i].z;
1096 }
1097 } else if (lhs_vert + 1 < m_size) {
1098 for (int i = 0; i < 8; i++) {
1099 output(lhs_vert, horiz_base + i) = results[i].x;
1100 output(lhs_vert + 1, horiz_base + i) = results[i].y;
1101 }
1102 } else if (lhs_vert < m_size) {
1103 for (int i = 0; i < 8; i++) {
1104 output(lhs_vert, horiz_base + i) = results[i].x;
1105 }
1106 }
1107 } else if (!CHECK_LHS_BOUNDARY) {
1108 // CHECK BOUNDARY_B
1109 for (int i = 0; i < 8; i++) {
1110 if (horiz_base + i < n_size) {
1111 output(lhs_vert, horiz_base + i) = results[i].x;
1112 output(lhs_vert + 1, horiz_base + i) = results[i].y;
1113 output(lhs_vert + 2, horiz_base + i) = results[i].z;
1114 output(lhs_vert + 3, horiz_base + i) = results[i].w;
1115 }
1116 }
1117 } else {
1118 // CHECK both boundaries.
1119 for (int i = 0; i < 8; i++) {
1120 if (horiz_base + i < n_size) {
1121 if (lhs_vert < m_size)
1122 output(lhs_vert, horiz_base + i) = results[i].x;
1123 if (lhs_vert + 1 < m_size)
1124 output(lhs_vert + 1, horiz_base + i) = results[i].y;
1125 if (lhs_vert + 2 < m_size)
1126 output(lhs_vert + 2, horiz_base + i) = results[i].z;
1127 if (lhs_vert + 3 < m_size)
1128 output(lhs_vert + 3, horiz_base + i) = results[i].w;
1129 }
1130 }
1131 }
1132 }
1133
1134
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1)
#else
__launch_bounds__(256)
#endif
EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  // Shared-memory staging buffers for the LHS and RHS tiles. They are
  // declared flat and reinterpreted below as the 2-D array types the worker
  // kernel expects.
  __shared__ float2 lhs_shmem[64*32];
  __shared__ float2 rhs_shmem[128*8];

  typedef float2 LHS_MEM[64][32];
  typedef float2 RHS_MEM[128][8];

  // Each block is responsible for a 128-row by 64-column tile of the output.
  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 128 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  // Interior blocks (tile fully inside the matrix) dispatch to worker
  // instantiations with bounds checking compiled out; edge blocks get the
  // checked variants.
  const bool lhs_full_tile = (base_m + 127) < m_size;
  const bool rhs_full_tile = (base_n + 63) < n_size;

  if (rhs_full_tile) {
    if (lhs_full_tile) {
      // >= 128 rows left
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (lhs_full_tile) {
      // >= 128 rows left
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  }
}
1181
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1)
#else
__launch_bounds__(256)
#endif
EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  // Shared-memory staging buffers handed to the 16x16-thread worker kernel.
  __shared__ float2 lhs_shmem[32][16];
  __shared__ float2 rhs_shmem[64][8];

  // Each block covers a 64x64 tile of the output.
  const Index base_m = 64 * blockIdx.x;
  const Index base_n = 64 * blockIdx.y;

  // Pick the worker instantiation whose compile-time boundary checks match
  // this block's position: interior tiles skip all bounds checking.
  const bool m_in_bounds = (base_m + 63) < m_size;
  const bool n_in_bounds = (base_n + 63) < n_size;

  if (m_in_bounds && n_in_bounds) {
    EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
        lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
  } else if (m_in_bounds) {
    EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
        lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
  } else if (n_in_bounds) {
    EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
        lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
  } else {
    EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
        lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
  }
}
1216
1217
// GPU specialization of the tensor-contraction evaluator. It inherits the
// generic bookkeeping (dimension/stride computation) from
// TensorContractionEvaluatorBase and overrides evaluation to launch one of
// the CUDA/HIP contraction kernels defined above in this file.
template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> :
    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> > {

  typedef GpuDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, GpuDevice>::type PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
  };

  // Most of the code is assuming that both input tensors are ColMajor. If the
  // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
  // If we want to compute A * B = C, where A is LHS and B is RHS, the code
  // will pretend B is LHS and A is RHS.
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  // Rank of the result: non-contracted dims of both operands.
  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef DSizes<Index, NumDims> Dimensions;

  // typedefs needed in evalTo
  typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
  typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

  typedef typename LeftEvaluator::Dimensions LeftDimensions;
  typedef typename RightEvaluator::Dimensions RightDimensions;

  TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device)
  {
    // Custom output kernels are not supported on the GPU path; only the
    // default NoOpOutputKernel is accepted.
    EIGEN_STATIC_ASSERT( (internal::is_same<OutputKernelType, const NoOpOutputKernel>::value),
                          GPU_TENSOR_CONTRACTION_DOES_NOT_SUPPORT_OUTPUT_KERNELS);
  }

  // We need to redefine this method to make nvcc happy
  // Returns false when the result was written into the caller-provided
  // buffer, true when this evaluator allocated (and owns) the result buffer.
  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    this->m_leftImpl.evalSubExprsIfNeeded(NULL);
    this->m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      this->m_result = static_cast<Scalar *>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(this->m_result);
      return true;
    }
  }

  // Dispatches on the three runtime layout flags so that evalTyped is
  // instantiated with compile-time booleans (one instantiation per
  // combination).
  void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, true, true, Unaligned>(buffer);
        }
        else {
          evalTyped<true, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, false, true, Unaligned>(buffer);
        }
        else {
          evalTyped<true, false, false, Unaligned>(buffer);
        }
      }
    }
    else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, true, true, Unaligned>(buffer);
        }
        else {
          evalTyped<false, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, false, true, Unaligned>(buffer);
        }
        else {
          evalTyped<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }

  // Generic-scalar launcher: one 64x64 output tile per block, 8x8x8 threads.
  template <typename LhsScalar, typename RhsScalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
      const Index m_blocks = (m + 63) / 64;
      const Index n_blocks = (n + 63) / 64;
      const dim3 num_blocks(m_blocks, n_blocks, 1);
      const dim3 block_size(8, 8, 8);
      LAUNCH_GPU_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
    }
  };

  // float/float specialization: uses the float4-vectorized kernels. Small
  // problems (m or n < 768) go to the 16x16-thread / 64x64-tile kernel;
  // larger ones to the 8x32-thread / 128x64-tile kernel.
  template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels<float, float, Index, LhsMapper, RhsMapper, OutputMapper> {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
      if (m < 768 || n < 768) {
        const Index m_blocks = (m + 63) / 64;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(16, 16, 1);
        LAUNCH_GPU_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      } else {
        const Index m_blocks = (m + 127) / 128;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(8, 32, 1);
        LAUNCH_GPU_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      }
    }
  };

  // Builds the input/output data mappers and launches the contraction
  // kernel appropriate for the scalar types.
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalTyped(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;
    // NOTE(review): k is still passed to LaunchKernels::Run below;
    // EIGEN_UNUSED_VARIABLE presumably only silences a compiler warning in
    // some build configurations — confirm.
    EIGEN_UNUSED_VARIABLE(k)

    // rows in left side
    const Index m = this->m_i_size;

    // columns in right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar)
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, 4,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, 4,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;


    // initialize data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    // Configure shared memory for eight-byte banks: the float kernels stage
    // data as float2, so this avoids bank conflicts on those accesses.
#if defined(EIGEN_USE_HIP)
    setGpuSharedMemConfig(hipSharedMemBankSizeEightByte);
#else
    setGpuSharedMemConfig(cudaSharedMemBankSizeEightByte);
#endif

    LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output, m, n, k, this->m_device);
  }
};
1409
1410 } // end namespace Eigen
1411
1412 #endif // EIGEN_USE_GPU and EIGEN_GPUCC
1413 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
1414