// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2020 Everton Constantino ([email protected])
// Copyright (C) 2021 Chip Kerchner ([email protected])
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H

#pragma GCC target("cpu=power10")

#ifdef __has_builtin
#if !__has_builtin(__builtin_vsx_assemble_pair)
#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
#endif
#endif
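// Note: some toolchains only provide the pair-assembly builtin under the
// __builtin_mma_* spelling; the alias above lets the LLVM path of ploadRhsMMA
// below use the newer __builtin_vsx_* name either way.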

namespace Eigen {

namespace internal {

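// Zero a 512-bit MMA accumulator (__vector_quad) in place; every micro-kernel
// tile below starts from a cleared accumulator before the rank-1 updates.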
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc)
{
  __builtin_mma_xxsetaccz(acc);
}

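// Store one accumulator tile to the result: split the __vector_quad into four
// row vectors, load the existing destination block, multiply the accumulated
// rows by alpha and add them in, then write the block back through the mapper.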
template<typename DataMapper, typename Index, typename Packet, const Index accCols>
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, Index j, const DataMapper& data, const Packet& alpha, __vector_quad* acc)
{
  PacketBlock<Packet, 4> result;
  __builtin_mma_disassemble_acc(&result.packet, acc);

  PacketBlock<Packet, 4> tRes;
  bload<DataMapper, Packet, Index, accCols, 0, ColMajor>(tRes, data, i, j);

  bscale<Packet>(tRes, result, alpha);

  data.template storePacketBlock<Packet, 4>(i, j, tRes);
}

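// Complex variant of storeAccumulator: real and imaginary parts live in separate
// accumulators, are scaled by the real/imaginary parts of alpha (bscalec), then
// interleaved back into complex packets (bcouple) and stored as two blocks.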
template<typename DataMapper, typename Index, typename Packet, typename Packetc, const Index accColsC, int N>
EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, Index j, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag)
{
  PacketBlock<Packet, 4> resultReal, resultImag;
  __builtin_mma_disassemble_acc(&resultReal.packet, accReal);
  __builtin_mma_disassemble_acc(&resultImag.packet, accImag);

  PacketBlock<Packetc, 8> tRes;
  bload<DataMapper, Packetc, Index, accColsC, N, ColMajor>(tRes, data, i, j);

  PacketBlock<Packet,4> taccReal, taccImag;
  bscalec<Packet,4>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag);

  PacketBlock<Packetc, 4> acc1, acc2;
  bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc1, acc2);

  data.template storePacketBlock<Packetc, 4>(i + N*accColsC, j, acc1);
  data.template storePacketBlock<Packetc, 4>(i + (N+1)*accColsC, j, acc2);
}

// Defaults to float32; since Eigen still supports C++03, we can't use default template arguments.
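// The xvf32ger builtins perform a rank-1 (outer-product) update of a 4x4 float
// accumulator: the "pp" form adds a*b to the accumulator, the "np" form
// subtracts it, which is what NegativeAccumulate selects.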
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b)
{
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  }
}

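// Double-precision rank-1 updates (xvf64ger*) take the two-row operand as a
// 256-bit __vector_pair: the first overload reinterprets a PacketBlock of two
// Packet2d as such a pair, the second receives an already assembled pair.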
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const PacketBlock<Packet2d,2>& a, const Packet2d& b)
{
  __vector_pair* a0 = (__vector_pair *)(&a.packet[0]);
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf64gernp(acc, *a0, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, *a0, (__vector unsigned char)b);
  }
}

template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet2d& b)
{
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);
  }
}

template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad*, const __vector_pair&, const Packet4f&)
{
  // Just for compilation
}

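// Complex rank-1 update split across a real and an imaginary accumulator. Using
// (lr + i*li)*(rr + i*ri) = (lr*rr - li*ri) + i*(lr*ri + li*rr), the real part
// always accumulates rhsV*lhsV plus, when both sides are complex, rhsVi*lhsVi
// with a sign chosen by the conjugation flags; the imaginary part accumulates
// the cross terms, again with the signs decided by the conjugation flags.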
template<typename Scalar, typename Packet, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, const Packet& lhsVi, const RhsPacket& rhsV, const RhsPacket& rhsVi)
{
  pgerMMA<Packet, RhsPacket, false>(accReal,  rhsV,  lhsV);
  if(LhsIsReal) {
    pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi,  lhsV);
  } else {
    if(!RhsIsReal) {
      pgerMMA<Packet, RhsPacket, ConjugateLhs == ConjugateRhs>(accReal, rhsVi, lhsVi);
      pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi,  lhsV);
    } else {
      EIGEN_UNUSED_VARIABLE(rhsVi);
    }
    pgerMMA<Packet, RhsPacket, ConjugateLhs>(accImag,  rhsV, lhsVi);
  }
}

// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled.
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV)
{
  rhsV = ploadRhs<Scalar, Packet>((const Scalar*)(rhs));
}

template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, PacketBlock<Packet2d, 2> >(const double* rhs, PacketBlock<Packet2d, 2>& rhsV)
{
  rhsV.packet[0] = ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs      ));
  rhsV.packet[1] = ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1));
}

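// For the __vector_pair form, GCC loads the two adjacent Packet2d with a single
// lxvp (256-bit paired load); LLVM assembles the pair from two ordinary loads
// via __builtin_vsx_assemble_pair instead.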
template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, __vector_pair>(const double* rhs, __vector_pair& rhsV)
{
#if EIGEN_COMP_LLVM
  __builtin_vsx_assemble_pair(&rhsV,
    (__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1))),
    (__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs      ))));
#else
  __asm__ ("lxvp %x0,%1" : "=wa" (rhsV) : "Y" (*rhs));
#endif
}

template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&)
{
  // Just for compilation
}

// PEEL_MMA loop factor: number of depth (k) iterations handled per peeled pass of the main loop.
#define PEEL_MMA 7

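// The MICRO_MMA_* macros below generate the register-blocked micro-kernel: up to
// eight lhs pointers and accumulators are named by token pasting (iter = 0..7),
// and the "unroll_factor > iter" / "PEEL_MMA > peel" guards let the compiler drop
// the unused instances at compile time.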
#define MICRO_MMA_UNROLL(func) \
  func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)

#define MICRO_MMA_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr##iter); \
    lhs_ptr##iter += accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
  }

#define MICRO_MMA_WORK_ONE(iter, type, peel) \
  if (unroll_factor > iter) { \
    pgerMMA<Packet, type, false>(&accZero##iter, rhsV##peel, lhsV##iter); \
  }

#define MICRO_MMA_TYPE_PEEL(func, func2, type, peel) \
  if (PEEL_MMA > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    ploadRhsMMA<Scalar, type>(rhs_ptr + (accRows * peel), rhsV##peel); \
    MICRO_MMA_UNROLL(func2); \
    func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
    func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
  }

#define MICRO_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
  type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \
  MICRO_MMA_TYPE_PEEL(func,func2,type,0); MICRO_MMA_TYPE_PEEL(func,func2,type,1); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,2); MICRO_MMA_TYPE_PEEL(func,func2,type,3); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,4); MICRO_MMA_TYPE_PEEL(func,func2,type,5); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,8); MICRO_MMA_TYPE_PEEL(func,func2,type,9);

#define MICRO_MMA_UNROLL_TYPE_ONE(func, func2, type) \
  type rhsV0; \
  MICRO_MMA_TYPE_PEEL(func,func2,type,0);

#define MICRO_MMA_ONE_PEEL \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr += (accRows * PEEL_MMA);

#define MICRO_MMA_ONE \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr += accRows;

#define MICRO_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzeroMMA<Scalar, Packet>(&accZero##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accZero##iter); \
  }

#define MICRO_MMA_DST_PTR MICRO_MMA_UNROLL(MICRO_MMA_DST_PTR_ONE)

#define MICRO_MMA_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
  }

#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_MMA_SRC_PTR_ONE)

#define MICRO_MMA_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
  }

#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_MMA_PREFETCH_ONE)

#define MICRO_MMA_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    storeAccumulator<DataMapper, Index, Packet, accCols>(row + iter*accCols, col, res, pAlpha, &accZero##iter); \
  }

#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE)

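// One register-blocked GEMM iteration: computes an (unroll_factor*accCols) x accRows
// tile of the result. Each active accumulator is zeroed, updated with one rank-1
// MMA product per k iteration (peeled by PEEL_MMA), then scaled by alpha and
// stored; 'row' is advanced by the number of rows processed.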
template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_STRONG_INLINE void gemm_unrolled_MMA_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index& row,
  Index col,
  const Packet& pAlpha)
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
  __vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;

  MICRO_MMA_SRC_PTR
  MICRO_MMA_DST_PTR

  Index k = 0;
  for(; k + PEEL_MMA <= depth; k+= PEEL_MMA)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    MICRO_MMA_PREFETCH
    MICRO_MMA_ONE_PEEL
  }
  for(; k < depth; k++)
  {
    MICRO_MMA_ONE
  }
  MICRO_MMA_STORE

  row += unroll_factor*accCols;
}

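// Top-level real GEMM driver for the packed LHS/RHS panels. Full accRows-wide
// column panels are processed with the widest unroll that still fits the
// remaining rows; leftover rows and columns fall back to the non-MMA helpers
// (gemm_extra_row, gemm_unrolled_col, gemm_extra_col) from MatrixProduct.h.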
template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
      const Index remaining_rows = rows % accCols;
      const Index remaining_cols = cols % accRows;

      if( strideA == -1 ) strideA = depth;
      if( strideB == -1 ) strideB = depth;

      const Packet pAlpha = pset1<Packet>(alpha);
      const Packet pMask  = bmask<Packet>((const int)(remaining_rows));

      Index col = 0;
      for(; col + accRows <= cols; col += accRows)
      {
        const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
        const Scalar* lhs_base = blockA;

        Index row = 0;
#define MAX_MMA_UNROLL 7
        while(row + MAX_MMA_UNROLL*accCols <= rows) {
          gemm_unrolled_MMA_iteration<MAX_MMA_UNROLL, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
        }
        switch( (rows-row)/accCols ) {
#if MAX_MMA_UNROLL > 7
          case 7:
            gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
            break;
#endif
#if MAX_MMA_UNROLL > 6
          case 6:
            gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
            break;
#endif
#if MAX_MMA_UNROLL > 5
          case 5:
            gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
            break;
#endif
#if MAX_MMA_UNROLL > 4
          case 4:
            gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
            break;
#endif
#if MAX_MMA_UNROLL > 3
          case 3:
            gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
            break;
#endif
#if MAX_MMA_UNROLL > 2
          case 2:
            gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
            break;
#endif
#if MAX_MMA_UNROLL > 1
          case 1:
            gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
            break;
#endif
          default:
            break;
        }
#undef MAX_MMA_UNROLL

        if(remaining_rows > 0)
        {
          gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
        }
      }

      if(remaining_cols > 0)
      {
        const Scalar* rhs_base = blockB + col*strideB + remaining_cols*offsetB;
        const Scalar* lhs_base = blockA;

        for(; col < cols; col++)
        {
          Index row = 0;

          gemm_unrolled_col<Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha);

          if (remaining_rows > 0)
          {
            gemm_extra_col<Scalar, Packet, DataMapper, Index, accRows>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha);
          }
          rhs_base++;
        }
      }
}

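// Helpers for the complex kernels: accColsC is the column-block width counted in
// complex elements (half the real width), while advanceRows/advanceCols account
// for the separately packed real and imaginary planes when a side is complex.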
#define accColsC (accCols / 2)
#define advanceRows ((LhsIsReal) ? 1 : 2)
#define advanceCols ((RhsIsReal) ? 1 : 2)

// PEEL_COMPLEX_MMA loop factor: number of depth (k) iterations handled per peeled pass of the complex loop.
#define PEEL_COMPLEX_MMA 7

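// Complex counterpart of the MICRO_MMA_* macros: each row tile needs a real and
// an imaginary accumulator (POWER10 exposes eight accumulators in total), so the
// unroll macros handle at most five row tiles and the driver below uses four.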
#define MICRO_COMPLEX_MMA_UNROLL(func) \
  func(0) func(1) func(2) func(3) func(4)

#define MICRO_COMPLEX_MMA_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
    lhs_ptr_real##iter += accCols; \
    if(!LhsIsReal) { \
      lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_imag##iter); \
      lhs_ptr_imag##iter += accCols; \
    } else { \
      EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
    } \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
    EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
  }

#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \
  if (unroll_factor > iter) { \
    pgercMMA<Scalar, Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
  }

#define MICRO_COMPLEX_MMA_TYPE_PEEL(func, func2, type, peel) \
  if (PEEL_COMPLEX_MMA > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \
    Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \
    ploadRhsMMA<Scalar, type>(rhs_ptr_real + (accRows * peel), rhsV##peel); \
    if(!RhsIsReal) { \
      ploadRhsMMA<Scalar, type>(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \
    } else { \
      EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
    } \
    MICRO_COMPLEX_MMA_UNROLL(func2); \
    func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) func(4,type,peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
  }

#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
  type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \
  type rhsVi0, rhsVi1, rhsVi2, rhsVi3, rhsVi4, rhsVi5, rhsVi6, rhsVi7, rhsVi8, rhsVi9; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,1); \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3); \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,4); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,5); \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,6); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,7); \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,8); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,9);

#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(func, func2, type) \
  type rhsV0, rhsVi0; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0);

#define MICRO_COMPLEX_MMA_ONE_PEEL \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr_real += (accRows * PEEL_COMPLEX_MMA); \
  if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_MMA);

#define MICRO_COMPLEX_MMA_ONE \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr_real += accRows; \
  if(!RhsIsReal) rhs_ptr_imag += accRows;

#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzeroMMA<Scalar, Packet>(&accReal##iter); \
    bsetzeroMMA<Scalar, Packet>(&accImag##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accReal##iter); \
    EIGEN_UNUSED_VARIABLE(accImag##iter); \
  }

#define MICRO_COMPLEX_MMA_DST_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_DST_PTR_ONE)

#define MICRO_COMPLEX_MMA_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols + accCols*offsetA; \
    if(!LhsIsReal) { \
      lhs_ptr_imag##iter = lhs_ptr_real##iter + accCols*strideA; \
    } else { \
      EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
    } \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
    EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
  }

#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_SRC_PTR_ONE)

#define MICRO_COMPLEX_MMA_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
    if(!LhsIsReal) { \
      EIGEN_POWER_PREFETCH(lhs_ptr_imag##iter); \
    } \
  }

#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_PREFETCH_ONE)

#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    storeComplexAccumulator<DataMapper, Index, Packet, Packetc, accColsC, 0>(row + iter*accCols, col, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \
  }

#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)

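// One register-blocked complex GEMM iteration: same structure as
// gemm_unrolled_MMA_iteration, but with separate real/imaginary lhs and rhs
// pointers and an accumulator pair per row tile, combined via pgercMMA.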
template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_complex_unrolled_MMA_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index& row,
  Index col,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag)
{
  const Scalar* rhs_ptr_real = rhs_base;
  const Scalar* rhs_ptr_imag;
  if(!RhsIsReal) {
    rhs_ptr_imag = rhs_base + accRows*strideB;
  } else {
    EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
  }
  const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL;
  const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL;
  const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL;
  __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3, accReal4, accImag4;

  MICRO_COMPLEX_MMA_SRC_PTR
  MICRO_COMPLEX_MMA_DST_PTR

  Index k = 0;
  for(; k + PEEL_COMPLEX_MMA <= depth; k+= PEEL_COMPLEX_MMA)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr_real);
    if(!RhsIsReal) {
      EIGEN_POWER_PREFETCH(rhs_ptr_imag);
    }
    MICRO_COMPLEX_MMA_PREFETCH
    MICRO_COMPLEX_MMA_ONE_PEEL
  }
  for(; k < depth; k++)
  {
    MICRO_COMPLEX_MMA_ONE
  }
  MICRO_COMPLEX_MMA_STORE

  row += unroll_factor*accCols;
}

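// Top-level complex GEMM driver. blockAc/blockBc hold the packed operands with
// real and imaginary planes stored separately; full column panels use the widest
// unroll that fits (at most MAX_COMPLEX_MMA_UNROLL row tiles), and leftover
// rows/columns fall back to the non-MMA complex helpers.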
template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
      const Index remaining_rows = rows % accCols;
      const Index remaining_cols = cols % accRows;

      if( strideA == -1 ) strideA = depth;
      if( strideB == -1 ) strideB = depth;

      const Packet pAlphaReal = pset1<Packet>(alpha.real());
      const Packet pAlphaImag = pset1<Packet>(alpha.imag());
      const Packet pMask = bmask<Packet>((const int)(remaining_rows));

      const Scalar* blockA = (Scalar *) blockAc;
      const Scalar* blockB = (Scalar *) blockBc;

      Index col = 0;
      for(; col + accRows <= cols; col += accRows)
      {
        const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
        const Scalar* lhs_base = blockA;
        Index row = 0;

#define MAX_COMPLEX_MMA_UNROLL 4
        while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) {
          gemm_complex_unrolled_MMA_iteration<MAX_COMPLEX_MMA_UNROLL, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
        }
        switch( (rows-row)/accCols ) {
#if MAX_COMPLEX_MMA_UNROLL > 4
          case 4:
            gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
            break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 3
          case 3:
            gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
            break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 2
          case 2:
            gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
            break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 1
          case 1:
            gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
            break;
#endif
          default:
            break;
        }
#undef MAX_COMPLEX_MMA_UNROLL

        if(remaining_rows > 0)
        {
          gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
        }
      }

      if(remaining_cols > 0)
      {
        const Scalar* rhs_base = blockB + advanceCols*col*strideB + remaining_cols*offsetB;
        const Scalar* lhs_base = blockA;

        for(; col < cols; col++)
        {
          Index row = 0;

          gemm_complex_unrolled_col<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, col, remaining_cols, pAlphaReal, pAlphaImag);

          if (remaining_rows > 0)
          {
            gemm_complex_extra_col<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_rows, remaining_cols, pAlphaReal, pAlphaImag);
          }
          rhs_base++;
        }
      }
}

#undef accColsC
#undef advanceRows
#undef advanceCols

#pragma GCC reset_options
} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H