// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2020 Everton Constantino ([email protected])
// Copyright (C) 2021 Chip Kerchner ([email protected])
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H

#pragma GCC target("cpu=power10")

#ifdef __has_builtin
#if !__has_builtin(__builtin_vsx_assemble_pair)
#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
#endif
#endif
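// The register-pair assembly builtin goes by different names depending on the
// compiler version; when __builtin_vsx_assemble_pair is unavailable, map it
// onto the __builtin_mma_assemble_pair spelling so the LLVM code path in
// ploadRhsMMA below still compiles.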

namespace Eigen {

namespace internal {

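// Zero an MMA accumulator (a 512-bit __vector_quad).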
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc)
{
  __builtin_mma_xxsetaccz(acc);
}

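// Flush one MMA accumulator to memory: unpack it into four packets, load the
// existing destination tile, fold in alpha * accumulator (bscale), and store
// the updated tile back through the DataMapper.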
template<typename DataMapper, typename Index, typename Packet, const Index accCols>
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, Index j, const DataMapper& data, const Packet& alpha, __vector_quad* acc)
{
  PacketBlock<Packet, 4> result;
  __builtin_mma_disassemble_acc(&result.packet, acc);

  PacketBlock<Packet, 4> tRes;
  bload<DataMapper, Packet, Index, accCols, 0, ColMajor>(tRes, data, i, j);

  bscale<Packet>(tRes, result, alpha);

  data.template storePacketBlock<Packet, 4>(i, j, tRes);
}

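// Complex counterpart of storeAccumulator: real and imaginary parts live in
// two separate accumulators. They are scaled by the complex alpha (bscalec),
// re-interleaved into complex packets (bcouple), and stored as two adjacent
// row blocks of the destination.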
template<typename DataMapper, typename Index, typename Packet, typename Packetc, const Index accColsC, int N>
EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, Index j, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag)
{
  PacketBlock<Packet, 4> resultReal, resultImag;
  __builtin_mma_disassemble_acc(&resultReal.packet, accReal);
  __builtin_mma_disassemble_acc(&resultImag.packet, accImag);

  PacketBlock<Packetc, 8> tRes;
  bload<DataMapper, Packetc, Index, accColsC, N, ColMajor>(tRes, data, i, j);

  PacketBlock<Packet,4> taccReal, taccImag;
  bscalec<Packet,4>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag);

  PacketBlock<Packetc, 4> acc1, acc2;
  bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc1, acc2);

  data.template storePacketBlock<Packetc, 4>(i + N*accColsC, j, acc1);
  data.template storePacketBlock<Packetc, 4>(i + (N+1)*accColsC, j, acc2);
}

// Defaults to float32; since Eigen still supports C++03, we can't use default template arguments.
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b)
{
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  }
}

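// Double-precision variants: the xvf64ger builtins take a __vector_pair
// operand for the rhs, either reinterpreted from a PacketBlock<Packet2d,2> or
// passed through directly. The (__vector_pair, Packet4f) overload is never
// called and exists only so every template instantiation compiles.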
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const PacketBlock<Packet2d,2>& a, const Packet2d& b)
{
  __vector_pair* a0 = (__vector_pair *)(&a.packet[0]);
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf64gernp(acc, *a0, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, *a0, (__vector unsigned char)b);
  }
}

template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet2d& b)
{
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);
  }
}

template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad*, const __vector_pair&, const Packet4f&)
{
  // Just for compilation
}

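// Complex rank-1 update built from real pgerMMA calls: accReal accumulates the
// real part (lhs_r*rhs_r and, when both sides are complex, +/- lhs_i*rhs_i
// depending on the conjugation flags), while accImag accumulates the imaginary
// part (lhs_r*rhs_i and lhs_i*rhs_r); terms that vanish when one side is real
// are skipped.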
template<typename Scalar, typename Packet, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, const Packet& lhsVi, const RhsPacket& rhsV, const RhsPacket& rhsVi)
{
  pgerMMA<Packet, RhsPacket, false>(accReal, rhsV, lhsV);
  if(LhsIsReal) {
    pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
  } else {
    if(!RhsIsReal) {
      pgerMMA<Packet, RhsPacket, ConjugateLhs == ConjugateRhs>(accReal, rhsVi, lhsVi);
      pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
    } else {
      EIGEN_UNUSED_VARIABLE(rhsVi);
    }
    pgerMMA<Packet, RhsPacket, ConjugateLhs>(accImag, rhsV, lhsVi);
  }
}

// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled.
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV)
{
  rhsV = ploadRhs<Scalar, Packet>((const Scalar*)(rhs));
}

template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, PacketBlock<Packet2d, 2> >(const double* rhs, PacketBlock<Packet2d, 2>& rhsV)
{
  rhsV.packet[0] = ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs));
  rhsV.packet[1] = ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1));
}

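// For double with MMA the rhs is loaded straight into a __vector_pair:
// clang assembles the pair from two 128-bit loads, while GCC uses a single
// lxvp instruction to load both halves at once.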
template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, __vector_pair>(const double* rhs, __vector_pair& rhsV)
{
#if EIGEN_COMP_LLVM
  __builtin_vsx_assemble_pair(&rhsV,
      (__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1))),
      (__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs))));
#else
  __asm__ ("lxvp %x0,%1" : "=wa" (rhsV) : "Y" (*rhs));
#endif
}

template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&)
{
  // Just for compilation
}

// PEEL_MMA loop factor.
#define PEEL_MMA 7

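// The MICRO_MMA_* macros below build the real GEMM micro kernel. "iter"
// selects one of up to 8 lhs row blocks (each accCols rows wide), with its own
// lhs pointer and MMA accumulator; "peel" selects one of the PEEL_MMA depth
// steps unrolled per pass of the k-loop. The "unroll_factor > iter" and
// "PEEL_MMA > peel" guards are compile-time constants, so unused iterations
// are discarded by the compiler.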
#define MICRO_MMA_UNROLL(func) \
  func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)

#define MICRO_MMA_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr##iter); \
    lhs_ptr##iter += accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
  }

#define MICRO_MMA_WORK_ONE(iter, type, peel) \
  if (unroll_factor > iter) { \
    pgerMMA<Packet, type, false>(&accZero##iter, rhsV##peel, lhsV##iter); \
  }

#define MICRO_MMA_TYPE_PEEL(func, func2, type, peel) \
  if (PEEL_MMA > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    ploadRhsMMA<Scalar, type>(rhs_ptr + (accRows * peel), rhsV##peel); \
    MICRO_MMA_UNROLL(func2); \
    func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
    func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
  }

#define MICRO_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
  type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \
  MICRO_MMA_TYPE_PEEL(func,func2,type,0); MICRO_MMA_TYPE_PEEL(func,func2,type,1); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,2); MICRO_MMA_TYPE_PEEL(func,func2,type,3); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,4); MICRO_MMA_TYPE_PEEL(func,func2,type,5); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,8); MICRO_MMA_TYPE_PEEL(func,func2,type,9);

#define MICRO_MMA_UNROLL_TYPE_ONE(func, func2, type) \
  type rhsV0; \
  MICRO_MMA_TYPE_PEEL(func,func2,type,0);

#define MICRO_MMA_ONE_PEEL \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr += (accRows * PEEL_MMA);

#define MICRO_MMA_ONE \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr += accRows;

#define MICRO_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzeroMMA<Scalar, Packet>(&accZero##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accZero##iter); \
  }

#define MICRO_MMA_DST_PTR MICRO_MMA_UNROLL(MICRO_MMA_DST_PTR_ONE)

#define MICRO_MMA_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
  }

#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_MMA_SRC_PTR_ONE)

#define MICRO_MMA_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
  }

#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_MMA_PREFETCH_ONE)

#define MICRO_MMA_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    storeAccumulator<DataMapper, Index, Packet, accCols>(row + iter*accCols, col, res, pAlpha, &accZero##iter); \
  }

#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE)

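// One unrolled iteration of the real kernel: accumulate an
// (unroll_factor * accCols) x accRows tile of the result over the full depth,
// one MMA accumulator per row block. The k-loop is peeled by PEEL_MMA steps,
// the remainder is handled one depth step at a time, and "row" is advanced by
// the number of rows processed.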
template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_STRONG_INLINE void gemm_unrolled_MMA_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index& row,
  Index col,
  const Packet& pAlpha)
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
  __vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;

  MICRO_MMA_SRC_PTR
  MICRO_MMA_DST_PTR

  Index k = 0;
  for(; k + PEEL_MMA <= depth; k+= PEEL_MMA)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    MICRO_MMA_PREFETCH
    MICRO_MMA_ONE_PEEL
  }
  for(; k < depth; k++)
  {
    MICRO_MMA_ONE
  }
  MICRO_MMA_STORE

  row += unroll_factor*accCols;
}

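// MMA GEMM driver for real scalars: walk the result in column panels of
// accRows, process full row panels with the largest unroll factor that fits
// (MAX_MMA_UNROLL), finish the leftover row blocks through the switch below,
// and hand the remaining rows/columns to the shared non-MMA helpers
// (gemm_extra_row, gemm_unrolled_col, gemm_extra_col).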
template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;
  const Index remaining_cols = cols % accRows;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlpha = pset1<Packet>(alpha);
  const Packet pMask = bmask<Packet>((const int)(remaining_rows));

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
    const Scalar* lhs_base = blockA;

    Index row = 0;
#define MAX_MMA_UNROLL 7
    while(row + MAX_MMA_UNROLL*accCols <= rows) {
      gemm_unrolled_MMA_iteration<MAX_MMA_UNROLL, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
    }
    switch( (rows-row)/accCols ) {
#if MAX_MMA_UNROLL > 7
      case 7:
        gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
        break;
#endif
#if MAX_MMA_UNROLL > 6
      case 6:
        gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
        break;
#endif
#if MAX_MMA_UNROLL > 5
      case 5:
        gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
        break;
#endif
#if MAX_MMA_UNROLL > 4
      case 4:
        gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
        break;
#endif
#if MAX_MMA_UNROLL > 3
      case 3:
        gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
        break;
#endif
#if MAX_MMA_UNROLL > 2
      case 2:
        gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
        break;
#endif
#if MAX_MMA_UNROLL > 1
      case 1:
        gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
        break;
#endif
      default:
        break;
    }
#undef MAX_MMA_UNROLL

    if(remaining_rows > 0)
    {
      gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
    }
  }

  if(remaining_cols > 0)
  {
    const Scalar* rhs_base = blockB + col*strideB + remaining_cols*offsetB;
    const Scalar* lhs_base = blockA;

    for(; col < cols; col++)
    {
      Index row = 0;

      gemm_unrolled_col<Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha);

      if (remaining_rows > 0)
      {
        gemm_extra_col<Scalar, Packet, DataMapper, Index, accRows>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha);
      }
      rhs_base++;
    }
  }
}

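// Helpers for the complex kernels: a complex packet (Packetc) holds half as
// many scalars as a real packet, hence accColsC = accCols / 2, and
// advanceRows/advanceCols double the packed-block strides whenever the
// corresponding operand actually carries an imaginary part (which the packing
// routines store as a separate block after the real part).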
#define accColsC (accCols / 2)
#define advanceRows ((LhsIsReal) ? 1 : 2)
#define advanceCols ((RhsIsReal) ? 1 : 2)

// PEEL_COMPLEX_MMA loop factor.
#define PEEL_COMPLEX_MMA 7

#define MICRO_COMPLEX_MMA_UNROLL(func) \
  func(0) func(1) func(2) func(3) func(4)

#define MICRO_COMPLEX_MMA_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
    lhs_ptr_real##iter += accCols; \
    if(!LhsIsReal) { \
      lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_imag##iter); \
      lhs_ptr_imag##iter += accCols; \
    } else { \
      EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
    } \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
    EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
  }

#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \
  if (unroll_factor > iter) { \
    pgercMMA<Scalar, Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
  }

#define MICRO_COMPLEX_MMA_TYPE_PEEL(func, func2, type, peel) \
  if (PEEL_COMPLEX_MMA > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \
    Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \
    ploadRhsMMA<Scalar, type>(rhs_ptr_real + (accRows * peel), rhsV##peel); \
    if(!RhsIsReal) { \
      ploadRhsMMA<Scalar, type>(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \
    } else { \
      EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
    } \
    MICRO_COMPLEX_MMA_UNROLL(func2); \
    func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) func(4,type,peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
  }

#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
  type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \
  type rhsVi0, rhsVi1, rhsVi2, rhsVi3, rhsVi4, rhsVi5, rhsVi6, rhsVi7, rhsVi8, rhsVi9; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,1); \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3); \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,4); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,5); \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,6); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,7); \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,8); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,9);

#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(func, func2, type) \
  type rhsV0, rhsVi0; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0);

#define MICRO_COMPLEX_MMA_ONE_PEEL \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr_real += (accRows * PEEL_COMPLEX_MMA); \
  if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_MMA);

#define MICRO_COMPLEX_MMA_ONE \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr_real += accRows; \
  if(!RhsIsReal) rhs_ptr_imag += accRows;

#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzeroMMA<Scalar, Packet>(&accReal##iter); \
    bsetzeroMMA<Scalar, Packet>(&accImag##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accReal##iter); \
    EIGEN_UNUSED_VARIABLE(accImag##iter); \
  }

#define MICRO_COMPLEX_MMA_DST_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_DST_PTR_ONE)

#define MICRO_COMPLEX_MMA_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols + accCols*offsetA; \
    if(!LhsIsReal) { \
      lhs_ptr_imag##iter = lhs_ptr_real##iter + accCols*strideA; \
    } else { \
      EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
    } \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
    EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
  }

#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_SRC_PTR_ONE)

#define MICRO_COMPLEX_MMA_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
    if(!LhsIsReal) { \
      EIGEN_POWER_PREFETCH(lhs_ptr_imag##iter); \
    } \
  }

#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_PREFETCH_ONE)

#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    storeComplexAccumulator<DataMapper, Index, Packet, Packetc, accColsC, 0>(row + iter*accCols, col, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \
  }

#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)

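// Complex analogue of gemm_unrolled_MMA_iteration: up to 5 row blocks, each
// with a real and an imaginary accumulator. Because the packed operands keep
// real and imaginary parts in separate blocks, every row block tracks a real
// and an imaginary lhs pointer, and the rhs gets its own imaginary pointer
// whenever the rhs is complex.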
template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_complex_unrolled_MMA_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index& row,
  Index col,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag)
{
  const Scalar* rhs_ptr_real = rhs_base;
  const Scalar* rhs_ptr_imag;
  if(!RhsIsReal) {
    rhs_ptr_imag = rhs_base + accRows*strideB;
  } else {
    EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
  }
  const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL;
  const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL;
  const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL;
  __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3, accReal4, accImag4;

  MICRO_COMPLEX_MMA_SRC_PTR
  MICRO_COMPLEX_MMA_DST_PTR

  Index k = 0;
  for(; k + PEEL_COMPLEX_MMA <= depth; k+= PEEL_COMPLEX_MMA)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr_real);
    if(!RhsIsReal) {
      EIGEN_POWER_PREFETCH(rhs_ptr_imag);
    }
    MICRO_COMPLEX_MMA_PREFETCH
    MICRO_COMPLEX_MMA_ONE_PEEL
  }
  for(; k < depth; k++)
  {
    MICRO_COMPLEX_MMA_ONE
  }
  MICRO_COMPLEX_MMA_STORE

  row += unroll_factor*accCols;
}

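// MMA GEMM driver for complex scalars. Same panel structure as gemmMMA, but
// with a maximum unroll of 4 (each row block occupies two of the eight MMA
// accumulators) and with the leftovers handled by the shared
// gemm_complex_extra_row/gemm_complex_unrolled_col/gemm_complex_extra_col
// helpers.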
template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;
  const Index remaining_cols = cols % accRows;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlphaReal = pset1<Packet>(alpha.real());
  const Packet pAlphaImag = pset1<Packet>(alpha.imag());
  const Packet pMask = bmask<Packet>((const int)(remaining_rows));

  const Scalar* blockA = (Scalar *) blockAc;
  const Scalar* blockB = (Scalar *) blockBc;

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
    const Scalar* lhs_base = blockA;
    Index row = 0;

#define MAX_COMPLEX_MMA_UNROLL 4
    while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) {
      gemm_complex_unrolled_MMA_iteration<MAX_COMPLEX_MMA_UNROLL, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
    }
    switch( (rows-row)/accCols ) {
#if MAX_COMPLEX_MMA_UNROLL > 4
      case 4:
        gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
        break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 3
      case 3:
        gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
        break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 2
      case 2:
        gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
        break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 1
      case 1:
        gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
        break;
#endif
      default:
        break;
    }
#undef MAX_COMPLEX_MMA_UNROLL

    if(remaining_rows > 0)
    {
      gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
    }
  }

  if(remaining_cols > 0)
  {
    const Scalar* rhs_base = blockB + advanceCols*col*strideB + remaining_cols*offsetB;
    const Scalar* lhs_base = blockA;

    for(; col < cols; col++)
    {
      Index row = 0;

      gemm_complex_unrolled_col<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, col, remaining_cols, pAlphaReal, pAlphaImag);

      if (remaining_rows > 0)
      {
        gemm_complex_extra_col<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_rows, remaining_cols, pAlphaReal, pAlphaImag);
      }
      rhs_base++;
    }
  }
}

#undef accColsC
#undef advanceRows
#undef advanceCols

#pragma GCC reset_options
} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H