xref: /aosp_15_r20/external/gemmlowp/internal/simd_wrappers.h (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1 // Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // simd_wrappers.h: some inline functions wrapping SIMD intrinsics,
16 // extending the set of such functions from fixedpoint.h.
17 
18 #ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
19 #define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
20 
#include <algorithm>
#include <cstdint>
#include <type_traits>

#include "../fixedpoint/fixedpoint.h"
24 
25 namespace gemmlowp {
26 
// Maps a (scalar type, scalar count) pair to the type stored in one
// "register". This primary template simply selects the scalar type itself.
template <typename ScalarType, int ScalarCount>
struct RegisterType {
  typedef ScalarType Type;
};
31 
// Returns the smaller of two 32-bit signed integers.
inline std::int32_t Min(std::int32_t a, std::int32_t b) {
  return (b < a) ? b : a;
}
35 
// Returns the larger of two 32-bit signed integers.
inline std::int32_t Max(std::int32_t a, std::int32_t b) {
  return (a < b) ? b : a;
}
39 
// Scalar multiply-accumulate: *acc += lhs * rhs.
//
// The arithmetic is carried out on unsigned 32-bit values so that an
// overflowing product or sum wraps around modulo 2^32 instead of being
// undefined behaviour, as it would be with signed operands. For inputs that
// do not overflow the result is identical to the plain signed expression.
// The final conversion back to std::int32_t relies on two's-complement
// wrap-around (guaranteed from C++20, and the behaviour of mainstream
// compilers before that).
//
// acc must point to a valid accumulator; it is read and then overwritten.
inline void MulAdd(std::int32_t lhs, std::int32_t rhs, std::int32_t* acc) {
  const std::uint32_t product =
      static_cast<std::uint32_t>(lhs) * static_cast<std::uint32_t>(rhs);
  const std::uint32_t sum = static_cast<std::uint32_t>(*acc) + product;
  *acc = static_cast<std::int32_t>(sum);
}
43 
44 template <typename tScalarType, int tScalarCount>
45 struct RegisterBuffer {
46   using ScalarType = tScalarType;
47   static constexpr int kScalarCount = tScalarCount;
48   using RegisterType = typename RegisterType<ScalarType, kScalarCount>::Type;
49   static_assert((kScalarCount & (kScalarCount - 1)) == 0,
50                 "kScalarCount must be a power of two");
51   static_assert(sizeof(RegisterType) % sizeof(ScalarType) == 0, "");
52   static constexpr int kRegisterLanes =
53       sizeof(RegisterType) / sizeof(ScalarType);
54   static constexpr int kRegisterCount =
55       (kScalarCount * sizeof(ScalarType) + sizeof(RegisterType) - 1) /
56       sizeof(RegisterType);
57 
58   RegisterType reg[kRegisterCount];
59 };
60 
61 template <typename tScalarType, int tRows, int tCols>
62 struct RegisterBlock {
63   using ScalarType = tScalarType;
64   static constexpr int kRows = tRows;
65   static constexpr int kCols = tCols;
66   static constexpr int kScalarCount = kRows * kCols;
67   using BufferType = RegisterBuffer<ScalarType, kScalarCount>;
68   using RegisterType = typename BufferType::RegisterType;
69   static constexpr int kRegisterCount = BufferType::kRegisterCount;
70   static constexpr int kRegisterLanes = BufferType::kRegisterLanes;
71 
72   BufferType buf;
73 };
74 
// Adds two register blocks of the same type, one register at a time, by
// calling Add() on each pair of corresponding registers.
template <typename RegisterBlockType>
struct RegisterBlockAddImpl {
  static RegisterBlockType Run(const RegisterBlockType& lhs,
                               const RegisterBlockType& rhs) {
    RegisterBlockType sum;
    int reg_index = 0;
    while (reg_index < RegisterBlockType::kRegisterCount) {
      sum.buf.reg[reg_index] =
          Add(lhs.buf.reg[reg_index], rhs.buf.reg[reg_index]);
      ++reg_index;
    }
    return sum;
  }
};
86 
87 template <typename RegisterBlockType>
RegisterBlockAdd(const RegisterBlockType & lhs,const RegisterBlockType & rhs)88 RegisterBlockType RegisterBlockAdd(const RegisterBlockType& lhs,
89                                    const RegisterBlockType& rhs) {
90   return RegisterBlockAddImpl<RegisterBlockType>::Run(lhs, rhs);
91 }
92 
// Decides whether the operands of a binary op should be swapped.
// kValue is true when the lhs has fewer scalars than the rhs, or when both
// have the same number of scalars and the lhs has fewer rows.
template <typename LhsType, typename RhsType>
struct ShouldFlipLhsRhs {
  static constexpr bool kValue =
      LhsType::kScalarCount != RhsType::kScalarCount
          ? LhsType::kScalarCount < RhsType::kScalarCount
          : LhsType::kRows < RhsType::kRows;
};
100 
101 template <typename LhsType, typename RhsType,
102           bool Flip = ShouldFlipLhsRhs<LhsType, RhsType>::kValue>
103 struct FlipLhsRhs {
104   using FlippedLhsType = LhsType;
105   using FlippedRhsType = RhsType;
FlippedLhsFlipLhsRhs106   static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
107                                           const RhsType& rhs) {
108     (void)rhs;
109     return lhs;
110   }
FlippedRhsFlipLhsRhs111   static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
112                                           const RhsType& rhs) {
113     (void)lhs;
114     return rhs;
115   }
116 };
117 
118 template <typename LhsType, typename RhsType>
119 struct FlipLhsRhs<LhsType, RhsType, true> {
120   using FlippedLhsType = RhsType;
121   using FlippedRhsType = LhsType;
122   static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
123                                           const RhsType& rhs) {
124     (void)lhs;
125     return rhs;
126   }
127   static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
128                                           const RhsType& rhs) {
129     (void)rhs;
130     return lhs;
131   }
132 };
133 
// Shape of the result of a broadcasting binary op: along each dimension the
// result extent is the larger of the two operand extents.
template <typename Lhs, typename Rhs>
struct BroadcastBinaryOpShape {
  static constexpr int kRows =
      Lhs::kRows < Rhs::kRows ? Rhs::kRows : Lhs::kRows;
  static constexpr int kCols =
      Lhs::kCols < Rhs::kCols ? Rhs::kCols : Lhs::kCols;
};
141 
142 template <typename Lhs, typename Rhs>
143 struct BroadcastBinaryOpRegisterBlock {
144   using Shape = BroadcastBinaryOpShape<Lhs, Rhs>;
145   using ScalarType = typename Lhs::ScalarType;
146   using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>;
147 };
148 
149 template <typename Lhs, typename Rhs>
150 struct BroadcastAddImpl {
151   using ResultBlockType =
152       typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
153   static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
154     ResultBlockType result;
155     static constexpr int Rows = ResultBlockType::kRows;
156     static constexpr int Cols = ResultBlockType::kCols;
157     static constexpr int LhsRows = Lhs::kRows;
158     static constexpr int LhsCols = Lhs::kCols;
159     static constexpr int RhsRows = Rhs::kRows;
160     static constexpr int RhsCols = Rhs::kCols;
161 
162     static_assert(LhsRows == Rows || LhsRows == 1, "");
163     static_assert(RhsRows == Rows || RhsRows == 1, "");
164     static_assert(LhsCols == Cols || LhsCols == 1, "");
165     static_assert(RhsCols == Cols || RhsCols == 1, "");
166     static_assert(ResultBlockType::kRegisterLanes == 1,
167                   "This path is only for scalar values");
168     static_assert(Lhs::kRegisterLanes == 1,
169                   "This path is only for scalar values");
170     static_assert(Rhs::kRegisterLanes == 1,
171                   "This path is only for scalar values");
172 
173     for (int c = 0; c < Cols; c++) {
174       const int lhs_c = LhsCols == Cols ? c : 0;
175       const int rhs_c = RhsCols == Cols ? c : 0;
176       for (int r = 0; r < Rows; r++) {
177         const int lhs_r = LhsRows == Rows ? r : 0;
178         const int rhs_r = RhsRows == Rows ? r : 0;
179         result.buf.reg[r + c * Rows] =
180             Add(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
181                 rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
182       }
183     }
184     return result;
185   }
186 };
187 
188 template <typename Lhs, typename Rhs>
189 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastAdd(
190     const Lhs& lhs, const Rhs& rhs) {
191   using Flip = FlipLhsRhs<Lhs, Rhs>;
192   return BroadcastAddImpl<
193       typename Flip::FlippedLhsType,
194       typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
195                                           Flip::FlippedRhs(lhs, rhs));
196 }
197 
198 template <typename Lhs, typename Rhs>
199 struct BroadcastShiftLeftImpl {
200   using ResultBlockType =
201       typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
202   static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
203     ResultBlockType result;
204     static constexpr int Rows = ResultBlockType::kRows;
205     static constexpr int Cols = ResultBlockType::kCols;
206     static constexpr int LhsRows = Lhs::kRows;
207     static constexpr int LhsCols = Lhs::kCols;
208     static constexpr int RhsRows = Rhs::kRows;
209     static constexpr int RhsCols = Rhs::kCols;
210 
211     static_assert(LhsRows == Rows || LhsRows == 1, "");
212     static_assert(RhsRows == Rows || RhsRows == 1, "");
213     static_assert(LhsCols == Cols || LhsCols == 1, "");
214     static_assert(RhsCols == Cols || RhsCols == 1, "");
215     static_assert(ResultBlockType::kRegisterLanes == 1,
216                   "This path is only for scalar values");
217     static_assert(Lhs::kRegisterLanes == 1,
218                   "This path is only for scalar values");
219     static_assert(Rhs::kRegisterLanes == 1,
220                   "This path is only for scalar values");
221 
222     for (int c = 0; c < Cols; c++) {
223       const int lhs_c = LhsCols == Cols ? c : 0;
224       const int rhs_c = RhsCols == Cols ? c : 0;
225       for (int r = 0; r < Rows; r++) {
226         const int lhs_r = LhsRows == Rows ? r : 0;
227         const int rhs_r = RhsRows == Rows ? r : 0;
228         result.buf.reg[r + c * Rows] =
229             ShiftLeft(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
230                       rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
231       }
232     }
233     return result;
234   }
235 };
236 
237 template <typename Lhs, typename Rhs>
238 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastShiftLeft(
239     const Lhs& lhs, const Rhs& rhs) {
240   using Flip = FlipLhsRhs<Lhs, Rhs>;
241   return BroadcastShiftLeftImpl<
242       typename Flip::FlippedLhsType,
243       typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
244                                           Flip::FlippedRhs(lhs, rhs));
245 }
246 
247 template <typename Lhs, typename Rhs>
248 struct BroadcastSaturatingRoundingDoublingHighMulImpl {
249   using ResultBlockType =
250       typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
251   static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
252     ResultBlockType result;
253     static constexpr int Rows = ResultBlockType::kRows;
254     static constexpr int Cols = ResultBlockType::kCols;
255     static constexpr int LhsRows = Lhs::kRows;
256     static constexpr int LhsCols = Lhs::kCols;
257     static constexpr int RhsRows = Rhs::kRows;
258     static constexpr int RhsCols = Rhs::kCols;
259 
260     static_assert(LhsRows == Rows || LhsRows == 1, "");
261     static_assert(RhsRows == Rows || RhsRows == 1, "");
262     static_assert(LhsCols == Cols || LhsCols == 1, "");
263     static_assert(RhsCols == Cols || RhsCols == 1, "");
264     static_assert(ResultBlockType::kRegisterLanes == 1,
265                   "This path is only for scalar values");
266     static_assert(Lhs::kRegisterLanes == 1,
267                   "This path is only for scalar values");
268     static_assert(Rhs::kRegisterLanes == 1,
269                   "This path is only for scalar values");
270 
271     for (int c = 0; c < Cols; c++) {
272       const int lhs_c = LhsCols == Cols ? c : 0;
273       const int rhs_c = RhsCols == Cols ? c : 0;
274       for (int r = 0; r < Rows; r++) {
275         const int lhs_r = LhsRows == Rows ? r : 0;
276         const int rhs_r = RhsRows == Rows ? r : 0;
277         result.buf.reg[r + c * Rows] = SaturatingRoundingDoublingHighMul(
278             lhs.buf.reg[lhs_r + lhs_c * LhsRows],
279             rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
280       }
281     }
282     return result;
283   }
284 };
285 
286 template <typename Lhs, typename Rhs>
287 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type
288 BroadcastSaturatingRoundingDoublingHighMul(const Lhs& lhs, const Rhs& rhs) {
289   using Flip = FlipLhsRhs<Lhs, Rhs>;
290   return BroadcastSaturatingRoundingDoublingHighMulImpl<
291       typename Flip::FlippedLhsType,
292       typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
293                                           Flip::FlippedRhs(lhs, rhs));
294 }
295 
296 template <typename Lhs, typename Rhs>
297 struct BroadcastRoundingDivideByPOTImpl {
298   using ResultBlockType =
299       typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
300   static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
301     ResultBlockType result;
302     static constexpr int Rows = ResultBlockType::kRows;
303     static constexpr int Cols = ResultBlockType::kCols;
304     static constexpr int LhsRows = Lhs::kRows;
305     static constexpr int LhsCols = Lhs::kCols;
306     static constexpr int RhsRows = Rhs::kRows;
307     static constexpr int RhsCols = Rhs::kCols;
308 
309     static_assert(LhsRows == Rows || LhsRows == 1, "");
310     static_assert(RhsRows == Rows || RhsRows == 1, "");
311     static_assert(LhsCols == Cols || LhsCols == 1, "");
312     static_assert(RhsCols == Cols || RhsCols == 1, "");
313     static_assert(ResultBlockType::kRegisterLanes == 1,
314                   "This path is only for scalar values");
315     static_assert(Lhs::kRegisterLanes == 1,
316                   "This path is only for scalar values");
317     static_assert(Rhs::kRegisterLanes == 1,
318                   "This path is only for scalar values");
319 
320     for (int c = 0; c < Cols; c++) {
321       const int lhs_c = LhsCols == Cols ? c : 0;
322       const int rhs_c = RhsCols == Cols ? c : 0;
323       for (int r = 0; r < Rows; r++) {
324         const int lhs_r = LhsRows == Rows ? r : 0;
325         const int rhs_r = RhsRows == Rows ? r : 0;
326         result.buf.reg[r + c * Rows] =
327             RoundingDivideByPOT(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
328                                 rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
329       }
330     }
331     return result;
332   }
333 };
334 
335 template <typename Lhs, typename Rhs>
336 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type
337 BroadcastRoundingDivideByPOT(const Lhs& lhs, const Rhs& rhs) {
338   using Flip = FlipLhsRhs<Lhs, Rhs>;
339   return BroadcastRoundingDivideByPOTImpl<
340       typename Flip::FlippedLhsType,
341       typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
342                                           Flip::FlippedRhs(lhs, rhs));
343 }
344 
345 template <typename Lhs, typename Rhs>
346 struct BroadcastMulImpl {
347   using ResultBlockType =
348       typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
349   static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
350     ResultBlockType result;
351     static constexpr int Rows = ResultBlockType::kRows;
352     static constexpr int Cols = ResultBlockType::kCols;
353     static constexpr int LhsRows = Lhs::kRows;
354     static constexpr int LhsCols = Lhs::kCols;
355     static constexpr int RhsRows = Rhs::kRows;
356     static constexpr int RhsCols = Rhs::kCols;
357     static_assert(ResultBlockType::kRegisterLanes == 1,
358                   "This path is only for scalar values");
359     static_assert(Lhs::kRegisterLanes == 1,
360                   "This path is only for scalar values");
361     static_assert(Rhs::kRegisterLanes == 1,
362                   "This path is only for scalar values");
363 
364     static_assert(LhsRows == Rows || LhsRows == 1, "");
365     static_assert(RhsRows == Rows || RhsRows == 1, "");
366     static_assert(LhsCols == Cols || LhsCols == 1, "");
367     static_assert(RhsCols == Cols || RhsCols == 1, "");
368     for (int c = 0; c < Cols; c++) {
369       const int lhs_c = LhsCols == Cols ? c : 0;
370       const int rhs_c = RhsCols == Cols ? c : 0;
371       for (int r = 0; r < Rows; r++) {
372         const int lhs_r = LhsRows == Rows ? r : 0;
373         const int rhs_r = RhsRows == Rows ? r : 0;
374         result.buf.reg[r + c * Rows] =
375             Mul(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
376                 rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
377       }
378     }
379     return result;
380   }
381 };
382 
383 template <typename Lhs, typename Rhs>
384 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastMul(
385     const Lhs& lhs, const Rhs& rhs) {
386   using Flip = FlipLhsRhs<Lhs, Rhs>;
387   return BroadcastMulImpl<
388       typename Flip::FlippedLhsType,
389       typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
390                                           Flip::FlippedRhs(lhs, rhs));
391 }
392 
// Generic scalar implementation of a broadcasting multiply-accumulate: for
// every element of *acc, calls MulAdd(lhs element, rhs element, &acc
// element). The accumulator defines the result shape; along each dimension
// an operand must either match that extent or have extent 1, in which case
// its single row/column is reused. Elements are addressed as
// row + col * rows. This path requires one scalar per register for the
// accumulator and both operands.
template <typename Lhs, typename Rhs, typename Acc>
struct BroadcastMulAddImpl {
  static void Run(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
    constexpr int kAccRows = Acc::kRows;
    constexpr int kAccCols = Acc::kCols;
    // True when the operand spans the full accumulator extent along that
    // dimension; false when it is broadcast from index 0.
    constexpr bool kLhsFullRows = Lhs::kRows == kAccRows;
    constexpr bool kRhsFullRows = Rhs::kRows == kAccRows;
    constexpr bool kLhsFullCols = Lhs::kCols == kAccCols;
    constexpr bool kRhsFullCols = Rhs::kCols == kAccCols;

    static_assert(Acc::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(Lhs::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(Rhs::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(kLhsFullRows || Lhs::kRows == 1, "");
    static_assert(kRhsFullRows || Rhs::kRows == 1, "");
    static_assert(kLhsFullCols || Lhs::kCols == 1, "");
    static_assert(kRhsFullCols || Rhs::kCols == 1, "");

    for (int col = 0; col < kAccCols; ++col) {
      const int lhs_col_offset = (kLhsFullCols ? col : 0) * Lhs::kRows;
      const int rhs_col_offset = (kRhsFullCols ? col : 0) * Rhs::kRows;
      for (int row = 0; row < kAccRows; ++row) {
        const int lhs_index = lhs_col_offset + (kLhsFullRows ? row : 0);
        const int rhs_index = rhs_col_offset + (kRhsFullRows ? row : 0);
        MulAdd(lhs.buf.reg[lhs_index], rhs.buf.reg[rhs_index],
               &acc->buf.reg[col * kAccRows + row]);
      }
    }
  }
};
426 
427 template <typename Lhs, typename Rhs, typename Acc>
428 void BroadcastMulAdd(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
429   using Flip = FlipLhsRhs<Lhs, Rhs>;
430   BroadcastMulAddImpl<typename Flip::FlippedLhsType,
431                       typename Flip::FlippedRhsType,
432                       Acc>::Run(Flip::FlippedLhs(lhs, rhs),
433                                 Flip::FlippedRhs(lhs, rhs), acc);
434 }
435 
// Catch-all LoadImpl. Loading is implemented by specializations for
// particular source object types; instantiating this primary template with
// any source type other than void fails the static_assert below.
template <typename RegisterBlockType, typename SrcObjectType>
struct LoadImpl {
  static_assert(std::is_same<void, SrcObjectType>::value,
                "This generic impl should never be hit");
};
441 
442 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType>
443 struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
444                 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
445   using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
446   using SrcObjectType = MatrixMap<SrcScalarType, MapOrder::ColMajor>;
447   static RegisterBlockType Run(const SrcObjectType& src, int row, int col) {
448     RegisterBlockType result;
449     int i = 0;
450     for (int c = 0; c < Cols; c++) {
451       const ScalarType* src_ptr = src.data(row, col + c);
452       for (int r = 0; r < Rows; r++) {
453         result.buf.reg[i++] = *src_ptr++;
454       }
455     }
456     return result;
457   }
458 };
459 
460 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
461           VectorShape Shape>
462 struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
463                 VectorMap<SrcScalarType, Shape>> {
464   using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
465   using SrcObjectType = VectorMap<SrcScalarType, Shape>;
466   static RegisterBlockType Run(const SrcObjectType& src, int pos) {
467     static_assert(Shape == VectorShape::Col || Rows == 1, "");
468     static_assert(Shape == VectorShape::Row || Cols == 1, "");
469     RegisterBlockType result;
470     for (int i = 0; i < Rows * Cols; i++) {
471       result.buf.reg[i] = src(pos + i);
472     }
473     return result;
474   }
475 };
476 
477 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
478           VectorShape Shape>
479 struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
480                 VectorDup<SrcScalarType, Shape>> {
481   using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
482   using SrcObjectType = VectorDup<SrcScalarType, Shape>;
483   static RegisterBlockType Run(const SrcObjectType& src, int) {
484     static_assert(Shape == VectorShape::Col || Rows == 1, "");
485     static_assert(Shape == VectorShape::Row || Cols == 1, "");
486     RegisterBlockType result;
487     for (int i = 0; i < Rows * Cols; i++) {
488       result.buf.reg[i] = src(0);
489     }
490     return result;
491   }
492 };
493 
494 template <typename RegisterBlockType, typename SrcObjectType>
495 RegisterBlockType Load(const SrcObjectType& src, int row, int col) {
496   return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, row, col);
497 }
498 
499 template <typename RegisterBlockType, typename SrcObjectType>
500 RegisterBlockType Load(const SrcObjectType& src, int pos) {
501   return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, pos);
502 }
503 
// Generic scalar implementation of a contiguous load: copies kScalarCount
// consecutive scalars from src into the block, in order. Only valid for
// block types holding one scalar per register.
template <typename RegisterBlockType>
struct LoadContiguousImpl {
  using ScalarType = typename RegisterBlockType::ScalarType;
  static_assert(RegisterBlockType::kRegisterLanes == 1,
                "This path is only for scalar values");

  static RegisterBlockType Run(const ScalarType* src) {
    RegisterBlockType block;
    int index = 0;
    while (index < RegisterBlockType::kScalarCount) {
      block.buf.reg[index] = src[index];
      ++index;
    }
    return block;
  }
};
517 
518 template <typename RegisterBlockType>
519 RegisterBlockType LoadContiguous(
520     const typename RegisterBlockType::ScalarType* src) {
521   return LoadContiguousImpl<RegisterBlockType>::Run(src);
522 }
523 
// Shape (kRows, kCols) of the block to load from a source object that will
// be broadcast against a BroadcastRows x BroadcastCols block. The primary
// template is intentionally empty; specializations for supported source
// types provide kRows and kCols.
template <int BroadcastRows, int BroadcastCols, typename SrcObjectType>
struct LoadForBroadcastingShape {
};
526 
527 template <int BroadcastRows, int BroadcastCols, typename ScalarType,
528           VectorShape Shape>
529 struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols,
530                                 VectorMap<ScalarType, Shape>> {
531   static constexpr int kRows = Shape == VectorShape::Col ? BroadcastRows : 1;
532   static constexpr int kCols = Shape == VectorShape::Row ? BroadcastCols : 1;
533 };
534 
535 template <int BroadcastRows, int BroadcastCols, typename ScalarType,
536           VectorShape Shape>
537 struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols,
538                                 VectorDup<ScalarType, Shape>> {
539   static constexpr int kRows = 1;
540   static constexpr int kCols = 1;
541 };
542 
543 template <typename RegisterBlockType, typename SrcObjectType>
544 struct LoadForBroadcastingRegisterBlock {
545   using Shape =
546       LoadForBroadcastingShape<RegisterBlockType::kRows,
547                                RegisterBlockType::kCols, SrcObjectType>;
548   using ScalarType = typename RegisterBlockType::ScalarType;
549   using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>;
550 };
551 
// Catch-all LoadForBroadcastingImpl. The work is done by specializations
// for particular source object types; instantiating this primary template
// with any source type other than void fails the static_assert below.
template <typename RegisterBlockType, typename SrcObjectType>
struct LoadForBroadcastingImpl {
  static_assert(std::is_same<void, SrcObjectType>::value,
                "This generic impl should never be hit");
};
557 
558 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
559           VectorShape Shape>
560 struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>,
561                                VectorMap<SrcScalarType, Shape>> {
562   using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
563   using SrcObjectType = VectorMap<SrcScalarType, Shape>;
564   using ResultBlockType =
565       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
566                                                 SrcObjectType>::Type;
567   static_assert(ResultBlockType::kRegisterLanes == 1,
568                 "This path is only for scalar values");
569   static ResultBlockType Run(const SrcObjectType& src, int pos) {
570     ResultBlockType result;
571     for (int c = 0; c < ResultBlockType::kCols; c++) {
572       for (int r = 0; r < ResultBlockType::kRows; r++) {
573         const int i = Shape == VectorShape::Col ? r : c;
574         result.buf.reg[r + c * ResultBlockType::kRows] = src(pos + i);
575       }
576     }
577     return result;
578   }
579 };
580 
581 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
582           VectorShape Shape>
583 struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>,
584                                VectorDup<SrcScalarType, Shape>> {
585   using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
586   using SrcObjectType = VectorDup<SrcScalarType, Shape>;
587   using ResultBlockType =
588       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
589                                                 SrcObjectType>::Type;
590   static_assert(ResultBlockType::kRegisterLanes == 1,
591                 "This path is only for scalar values");
592   static ResultBlockType Run(const SrcObjectType& src, int) {
593     ResultBlockType result;
594     for (int c = 0; c < ResultBlockType::kCols; c++) {
595       for (int r = 0; r < ResultBlockType::kRows; r++) {
596         result.buf.reg[r + c * ResultBlockType::kRows] = src(0);
597       }
598     }
599     return result;
600   }
601 };
602 
603 template <typename RegisterBlockType, typename SrcObjectType>
604 typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
605                                           SrcObjectType>::Type
606 LoadForBroadcasting(const SrcObjectType& src, int row, int col) {
607   return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(
608       src, row, col);
609 }
610 
611 template <typename RegisterBlockType, typename SrcObjectType>
612 typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
613                                           SrcObjectType>::Type
614 LoadForBroadcasting(const SrcObjectType& src, int pos) {
615   return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(src,
616                                                                         pos);
617 }
618 
619 template <int ConstantValue, typename RegisterBlockType>
620 struct AddConstantImpl {
621   static void Run(RegisterBlockType* block) {
622     using RegisterType = typename RegisterBlockType::RegisterType;
623     const RegisterType dup = Dup<RegisterType>(ConstantValue);
624     for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) {
625       block->buf.reg[i] = Add(block->buf.reg[i], dup);
626     }
627   }
628 };
629 
630 template <typename RegisterBlockType>
631 struct AddConstantImpl<0, RegisterBlockType> {
632   static void Run(RegisterBlockType*) {
633     // This is a no-op.
634   }
635 };
636 
637 template <int ConstantValue, typename RegisterBlockType>
638 void AddConstant(RegisterBlockType* block) {
639   AddConstantImpl<ConstantValue, RegisterBlockType>::Run(block);
640 }
641 
642 template <int N>
643 using RegBufferInt32 = RegisterBuffer<std::int32_t, N>;
644 template <int N>
645 using RegBufferInt16 = RegisterBuffer<std::int16_t, N>;
646 template <int N>
647 using RegBufferUint8 = RegisterBuffer<std::uint8_t, N>;
648 template <int N>
649 using RegBufferInt8 = RegisterBuffer<std::int8_t, N>;
650 template <int R, int C>
651 using RegBlockInt32 = RegisterBlock<std::int32_t, R, C>;
652 template <int R, int C>
653 using RegBlockInt16 = RegisterBlock<std::int16_t, R, C>;
654 template <int R, int C>
655 using RegBlockUint8 = RegisterBlock<std::uint8_t, R, C>;
656 template <int R, int C>
657 using RegBlockInt8 = RegisterBlock<std::int8_t, R, C>;
658 
659 }  // end namespace gemmlowp
660 
661 #if defined GEMMLOWP_NEON
662 #include "simd_wrappers_neon.h"
663 #elif defined GEMMLOWP_SSE4
664 #include "simd_wrappers_sse.h"
665 #elif defined GEMMLOWP_MSA
666 #include "simd_wrappers_msa.h"
667 #endif
668 
669 #endif  // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
670