1 // Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // simd_wrappers.h: some inline functions wrapping SIMD intrinsics,
16 // extending the set of such functions from fixedpoint.h.
17
18 #ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
19 #define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
20
21 #include <algorithm>
22 #include <type_traits>
23 #include "../fixedpoint/fixedpoint.h"
24
25 namespace gemmlowp {
26
// Trait mapping (scalar type, scalar count) to the type of one "register"
// holding such scalars. This generic version uses the scalar type itself,
// i.e. one scalar per register. The platform headers included at the end of
// this file presumably specialize it for SIMD vector types (not visible here).
template <typename tScalarType, int tScalarCount>
struct RegisterType {
  using Type = tScalarType;
};
31
// Returns the smaller of two 32-bit signed integers (a when they are equal).
inline std::int32_t Min(std::int32_t a, std::int32_t b) {
  return (b < a) ? b : a;
}
35
// Returns the larger of two 32-bit signed integers (a when they are equal).
inline std::int32_t Max(std::int32_t a, std::int32_t b) {
  return (a < b) ? b : a;
}
39
// Scalar multiply-accumulate: *acc += lhs * rhs.
//
// The arithmetic is carried out in uint32 so that overflow of the product or
// the sum wraps modulo 2^32 instead of being undefined behavior, as it would
// be with signed int32 arithmetic. The final conversion back to int32 is
// modular on mainstream compilers and guaranteed by the standard since C++20.
inline void MulAdd(std::int32_t lhs, std::int32_t rhs, std::int32_t* acc) {
  const std::uint32_t product =
      static_cast<std::uint32_t>(lhs) * static_cast<std::uint32_t>(rhs);
  const std::uint32_t sum = static_cast<std::uint32_t>(*acc) + product;
  *acc = static_cast<std::int32_t>(sum);
}
43
44 template <typename tScalarType, int tScalarCount>
45 struct RegisterBuffer {
46 using ScalarType = tScalarType;
47 static constexpr int kScalarCount = tScalarCount;
48 using RegisterType = typename RegisterType<ScalarType, kScalarCount>::Type;
49 static_assert((kScalarCount & (kScalarCount - 1)) == 0,
50 "kScalarCount must be a power of two");
51 static_assert(sizeof(RegisterType) % sizeof(ScalarType) == 0, "");
52 static constexpr int kRegisterLanes =
53 sizeof(RegisterType) / sizeof(ScalarType);
54 static constexpr int kRegisterCount =
55 (kScalarCount * sizeof(ScalarType) + sizeof(RegisterType) - 1) /
56 sizeof(RegisterType);
57
58 RegisterType reg[kRegisterCount];
59 };
60
61 template <typename tScalarType, int tRows, int tCols>
62 struct RegisterBlock {
63 using ScalarType = tScalarType;
64 static constexpr int kRows = tRows;
65 static constexpr int kCols = tCols;
66 static constexpr int kScalarCount = kRows * kCols;
67 using BufferType = RegisterBuffer<ScalarType, kScalarCount>;
68 using RegisterType = typename BufferType::RegisterType;
69 static constexpr int kRegisterCount = BufferType::kRegisterCount;
70 static constexpr int kRegisterLanes = BufferType::kRegisterLanes;
71
72 BufferType buf;
73 };
74
// Generic element-wise sum of two register blocks of the same type: applies
// Add() register by register over the whole buffer.
template <typename RegisterBlockType>
struct RegisterBlockAddImpl {
  static RegisterBlockType Run(const RegisterBlockType& lhs,
                               const RegisterBlockType& rhs) {
    constexpr int kCount = RegisterBlockType::kRegisterCount;
    RegisterBlockType sum;
    for (int idx = 0; idx < kCount; ++idx) {
      const auto& a = lhs.buf.reg[idx];
      const auto& b = rhs.buf.reg[idx];
      sum.buf.reg[idx] = Add(a, b);
    }
    return sum;
  }
};
86
87 template <typename RegisterBlockType>
RegisterBlockAdd(const RegisterBlockType & lhs,const RegisterBlockType & rhs)88 RegisterBlockType RegisterBlockAdd(const RegisterBlockType& lhs,
89 const RegisterBlockType& rhs) {
90 return RegisterBlockAddImpl<RegisterBlockType>::Run(lhs, rhs);
91 }
92
// True when the operands of a broadcast binary op should be swapped: the
// operand with more scalars should come first, and on a tie the one with more
// rows. Presumably this lets *Impl specializations cover only one operand
// ordering (confirm against the SIMD headers).
template <typename LhsType, typename RhsType>
struct ShouldFlipLhsRhs {
  static constexpr bool kValue =
      (LhsType::kScalarCount != RhsType::kScalarCount)
          ? (LhsType::kScalarCount < RhsType::kScalarCount)
          : (LhsType::kRows < RhsType::kRows);
};
100
101 template <typename LhsType, typename RhsType,
102 bool Flip = ShouldFlipLhsRhs<LhsType, RhsType>::kValue>
103 struct FlipLhsRhs {
104 using FlippedLhsType = LhsType;
105 using FlippedRhsType = RhsType;
FlippedLhsFlipLhsRhs106 static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
107 const RhsType& rhs) {
108 (void)rhs;
109 return lhs;
110 }
FlippedRhsFlipLhsRhs111 static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
112 const RhsType& rhs) {
113 (void)lhs;
114 return rhs;
115 }
116 };
117
118 template <typename LhsType, typename RhsType>
119 struct FlipLhsRhs<LhsType, RhsType, true> {
120 using FlippedLhsType = RhsType;
121 using FlippedRhsType = LhsType;
122 static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
123 const RhsType& rhs) {
124 (void)lhs;
125 return rhs;
126 }
127 static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
128 const RhsType& rhs) {
129 (void)rhs;
130 return lhs;
131 }
132 };
133
// Result shape of a broadcast binary op: in each dimension, the larger of the
// two operands' extents.
template <typename Lhs, typename Rhs>
struct BroadcastBinaryOpShape {
  static constexpr int kRows =
      Lhs::kRows < Rhs::kRows ? Rhs::kRows : Lhs::kRows;
  static constexpr int kCols =
      Lhs::kCols < Rhs::kCols ? Rhs::kCols : Lhs::kCols;
};
141
142 template <typename Lhs, typename Rhs>
143 struct BroadcastBinaryOpRegisterBlock {
144 using Shape = BroadcastBinaryOpShape<Lhs, Rhs>;
145 using ScalarType = typename Lhs::ScalarType;
146 using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>;
147 };
148
149 template <typename Lhs, typename Rhs>
150 struct BroadcastAddImpl {
151 using ResultBlockType =
152 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
153 static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
154 ResultBlockType result;
155 static constexpr int Rows = ResultBlockType::kRows;
156 static constexpr int Cols = ResultBlockType::kCols;
157 static constexpr int LhsRows = Lhs::kRows;
158 static constexpr int LhsCols = Lhs::kCols;
159 static constexpr int RhsRows = Rhs::kRows;
160 static constexpr int RhsCols = Rhs::kCols;
161
162 static_assert(LhsRows == Rows || LhsRows == 1, "");
163 static_assert(RhsRows == Rows || RhsRows == 1, "");
164 static_assert(LhsCols == Cols || LhsCols == 1, "");
165 static_assert(RhsCols == Cols || RhsCols == 1, "");
166 static_assert(ResultBlockType::kRegisterLanes == 1,
167 "This path is only for scalar values");
168 static_assert(Lhs::kRegisterLanes == 1,
169 "This path is only for scalar values");
170 static_assert(Rhs::kRegisterLanes == 1,
171 "This path is only for scalar values");
172
173 for (int c = 0; c < Cols; c++) {
174 const int lhs_c = LhsCols == Cols ? c : 0;
175 const int rhs_c = RhsCols == Cols ? c : 0;
176 for (int r = 0; r < Rows; r++) {
177 const int lhs_r = LhsRows == Rows ? r : 0;
178 const int rhs_r = RhsRows == Rows ? r : 0;
179 result.buf.reg[r + c * Rows] =
180 Add(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
181 rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
182 }
183 }
184 return result;
185 }
186 };
187
188 template <typename Lhs, typename Rhs>
189 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastAdd(
190 const Lhs& lhs, const Rhs& rhs) {
191 using Flip = FlipLhsRhs<Lhs, Rhs>;
192 return BroadcastAddImpl<
193 typename Flip::FlippedLhsType,
194 typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
195 Flip::FlippedRhs(lhs, rhs));
196 }
197
198 template <typename Lhs, typename Rhs>
199 struct BroadcastShiftLeftImpl {
200 using ResultBlockType =
201 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
202 static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
203 ResultBlockType result;
204 static constexpr int Rows = ResultBlockType::kRows;
205 static constexpr int Cols = ResultBlockType::kCols;
206 static constexpr int LhsRows = Lhs::kRows;
207 static constexpr int LhsCols = Lhs::kCols;
208 static constexpr int RhsRows = Rhs::kRows;
209 static constexpr int RhsCols = Rhs::kCols;
210
211 static_assert(LhsRows == Rows || LhsRows == 1, "");
212 static_assert(RhsRows == Rows || RhsRows == 1, "");
213 static_assert(LhsCols == Cols || LhsCols == 1, "");
214 static_assert(RhsCols == Cols || RhsCols == 1, "");
215 static_assert(ResultBlockType::kRegisterLanes == 1,
216 "This path is only for scalar values");
217 static_assert(Lhs::kRegisterLanes == 1,
218 "This path is only for scalar values");
219 static_assert(Rhs::kRegisterLanes == 1,
220 "This path is only for scalar values");
221
222 for (int c = 0; c < Cols; c++) {
223 const int lhs_c = LhsCols == Cols ? c : 0;
224 const int rhs_c = RhsCols == Cols ? c : 0;
225 for (int r = 0; r < Rows; r++) {
226 const int lhs_r = LhsRows == Rows ? r : 0;
227 const int rhs_r = RhsRows == Rows ? r : 0;
228 result.buf.reg[r + c * Rows] =
229 ShiftLeft(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
230 rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
231 }
232 }
233 return result;
234 }
235 };
236
237 template <typename Lhs, typename Rhs>
238 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastShiftLeft(
239 const Lhs& lhs, const Rhs& rhs) {
240 using Flip = FlipLhsRhs<Lhs, Rhs>;
241 return BroadcastShiftLeftImpl<
242 typename Flip::FlippedLhsType,
243 typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
244 Flip::FlippedRhs(lhs, rhs));
245 }
246
247 template <typename Lhs, typename Rhs>
248 struct BroadcastSaturatingRoundingDoublingHighMulImpl {
249 using ResultBlockType =
250 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
251 static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
252 ResultBlockType result;
253 static constexpr int Rows = ResultBlockType::kRows;
254 static constexpr int Cols = ResultBlockType::kCols;
255 static constexpr int LhsRows = Lhs::kRows;
256 static constexpr int LhsCols = Lhs::kCols;
257 static constexpr int RhsRows = Rhs::kRows;
258 static constexpr int RhsCols = Rhs::kCols;
259
260 static_assert(LhsRows == Rows || LhsRows == 1, "");
261 static_assert(RhsRows == Rows || RhsRows == 1, "");
262 static_assert(LhsCols == Cols || LhsCols == 1, "");
263 static_assert(RhsCols == Cols || RhsCols == 1, "");
264 static_assert(ResultBlockType::kRegisterLanes == 1,
265 "This path is only for scalar values");
266 static_assert(Lhs::kRegisterLanes == 1,
267 "This path is only for scalar values");
268 static_assert(Rhs::kRegisterLanes == 1,
269 "This path is only for scalar values");
270
271 for (int c = 0; c < Cols; c++) {
272 const int lhs_c = LhsCols == Cols ? c : 0;
273 const int rhs_c = RhsCols == Cols ? c : 0;
274 for (int r = 0; r < Rows; r++) {
275 const int lhs_r = LhsRows == Rows ? r : 0;
276 const int rhs_r = RhsRows == Rows ? r : 0;
277 result.buf.reg[r + c * Rows] = SaturatingRoundingDoublingHighMul(
278 lhs.buf.reg[lhs_r + lhs_c * LhsRows],
279 rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
280 }
281 }
282 return result;
283 }
284 };
285
286 template <typename Lhs, typename Rhs>
287 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type
288 BroadcastSaturatingRoundingDoublingHighMul(const Lhs& lhs, const Rhs& rhs) {
289 using Flip = FlipLhsRhs<Lhs, Rhs>;
290 return BroadcastSaturatingRoundingDoublingHighMulImpl<
291 typename Flip::FlippedLhsType,
292 typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
293 Flip::FlippedRhs(lhs, rhs));
294 }
295
296 template <typename Lhs, typename Rhs>
297 struct BroadcastRoundingDivideByPOTImpl {
298 using ResultBlockType =
299 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
300 static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
301 ResultBlockType result;
302 static constexpr int Rows = ResultBlockType::kRows;
303 static constexpr int Cols = ResultBlockType::kCols;
304 static constexpr int LhsRows = Lhs::kRows;
305 static constexpr int LhsCols = Lhs::kCols;
306 static constexpr int RhsRows = Rhs::kRows;
307 static constexpr int RhsCols = Rhs::kCols;
308
309 static_assert(LhsRows == Rows || LhsRows == 1, "");
310 static_assert(RhsRows == Rows || RhsRows == 1, "");
311 static_assert(LhsCols == Cols || LhsCols == 1, "");
312 static_assert(RhsCols == Cols || RhsCols == 1, "");
313 static_assert(ResultBlockType::kRegisterLanes == 1,
314 "This path is only for scalar values");
315 static_assert(Lhs::kRegisterLanes == 1,
316 "This path is only for scalar values");
317 static_assert(Rhs::kRegisterLanes == 1,
318 "This path is only for scalar values");
319
320 for (int c = 0; c < Cols; c++) {
321 const int lhs_c = LhsCols == Cols ? c : 0;
322 const int rhs_c = RhsCols == Cols ? c : 0;
323 for (int r = 0; r < Rows; r++) {
324 const int lhs_r = LhsRows == Rows ? r : 0;
325 const int rhs_r = RhsRows == Rows ? r : 0;
326 result.buf.reg[r + c * Rows] =
327 RoundingDivideByPOT(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
328 rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
329 }
330 }
331 return result;
332 }
333 };
334
335 template <typename Lhs, typename Rhs>
336 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type
337 BroadcastRoundingDivideByPOT(const Lhs& lhs, const Rhs& rhs) {
338 using Flip = FlipLhsRhs<Lhs, Rhs>;
339 return BroadcastRoundingDivideByPOTImpl<
340 typename Flip::FlippedLhsType,
341 typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
342 Flip::FlippedRhs(lhs, rhs));
343 }
344
345 template <typename Lhs, typename Rhs>
346 struct BroadcastMulImpl {
347 using ResultBlockType =
348 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
349 static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
350 ResultBlockType result;
351 static constexpr int Rows = ResultBlockType::kRows;
352 static constexpr int Cols = ResultBlockType::kCols;
353 static constexpr int LhsRows = Lhs::kRows;
354 static constexpr int LhsCols = Lhs::kCols;
355 static constexpr int RhsRows = Rhs::kRows;
356 static constexpr int RhsCols = Rhs::kCols;
357 static_assert(ResultBlockType::kRegisterLanes == 1,
358 "This path is only for scalar values");
359 static_assert(Lhs::kRegisterLanes == 1,
360 "This path is only for scalar values");
361 static_assert(Rhs::kRegisterLanes == 1,
362 "This path is only for scalar values");
363
364 static_assert(LhsRows == Rows || LhsRows == 1, "");
365 static_assert(RhsRows == Rows || RhsRows == 1, "");
366 static_assert(LhsCols == Cols || LhsCols == 1, "");
367 static_assert(RhsCols == Cols || RhsCols == 1, "");
368 for (int c = 0; c < Cols; c++) {
369 const int lhs_c = LhsCols == Cols ? c : 0;
370 const int rhs_c = RhsCols == Cols ? c : 0;
371 for (int r = 0; r < Rows; r++) {
372 const int lhs_r = LhsRows == Rows ? r : 0;
373 const int rhs_r = RhsRows == Rows ? r : 0;
374 result.buf.reg[r + c * Rows] =
375 Mul(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
376 rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
377 }
378 }
379 return result;
380 }
381 };
382
383 template <typename Lhs, typename Rhs>
384 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastMul(
385 const Lhs& lhs, const Rhs& rhs) {
386 using Flip = FlipLhsRhs<Lhs, Rhs>;
387 return BroadcastMulImpl<
388 typename Flip::FlippedLhsType,
389 typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
390 Flip::FlippedRhs(lhs, rhs));
391 }
392
// Scalar (one lane per register) fallback for broadcasting multiply-accumulate.
// For every element of *acc, MulAdd(lhs, rhs, &acc_elem) is applied, with lhs
// and rhs broadcast along any dimension of size 1. The accumulator's shape
// defines the iteration space.
template <typename Lhs, typename Rhs, typename Acc>
struct BroadcastMulAddImpl {
  static void Run(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
    constexpr int kOutRows = Acc::kRows;
    constexpr int kOutCols = Acc::kCols;
    static_assert(Acc::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(Lhs::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(Rhs::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(Lhs::kRows == kOutRows || Lhs::kRows == 1, "");
    static_assert(Rhs::kRows == kOutRows || Rhs::kRows == 1, "");
    static_assert(Lhs::kCols == kOutCols || Lhs::kCols == 1, "");
    static_assert(Rhs::kCols == kOutCols || Rhs::kCols == 1, "");

    // Column-major strides into each operand. A broadcast dimension gets
    // stride 0 so the same element is reused along it.
    const int lhs_row_stride = (Lhs::kRows == kOutRows) ? 1 : 0;
    const int lhs_col_stride = (Lhs::kCols == kOutCols) ? Lhs::kRows : 0;
    const int rhs_row_stride = (Rhs::kRows == kOutRows) ? 1 : 0;
    const int rhs_col_stride = (Rhs::kCols == kOutCols) ? Rhs::kRows : 0;

    int dst = 0;
    for (int col = 0; col < kOutCols; ++col) {
      for (int row = 0; row < kOutRows; ++row) {
        MulAdd(lhs.buf.reg[row * lhs_row_stride + col * lhs_col_stride],
               rhs.buf.reg[row * rhs_row_stride + col * rhs_col_stride],
               &acc->buf.reg[dst++]);
      }
    }
  }
};
426
427 template <typename Lhs, typename Rhs, typename Acc>
428 void BroadcastMulAdd(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
429 using Flip = FlipLhsRhs<Lhs, Rhs>;
430 BroadcastMulAddImpl<typename Flip::FlippedLhsType,
431 typename Flip::FlippedRhsType,
432 Acc>::Run(Flip::FlippedLhs(lhs, rhs),
433 Flip::FlippedRhs(lhs, rhs), acc);
434 }
435
// Primary template for loading a register block from a source object. Only
// the partial specializations are meant to be used. Instantiating this one
// triggers the static_assert below (for any non-void source type).
template <typename RegisterBlockType, typename SrcObjectType>
struct LoadImpl {
  static_assert(std::is_same<void, SrcObjectType>::value,
                "This generic impl should never be hit");
};
441
442 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType>
443 struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
444 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
445 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
446 using SrcObjectType = MatrixMap<SrcScalarType, MapOrder::ColMajor>;
447 static RegisterBlockType Run(const SrcObjectType& src, int row, int col) {
448 RegisterBlockType result;
449 int i = 0;
450 for (int c = 0; c < Cols; c++) {
451 const ScalarType* src_ptr = src.data(row, col + c);
452 for (int r = 0; r < Rows; r++) {
453 result.buf.reg[i++] = *src_ptr++;
454 }
455 }
456 return result;
457 }
458 };
459
460 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
461 VectorShape Shape>
462 struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
463 VectorMap<SrcScalarType, Shape>> {
464 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
465 using SrcObjectType = VectorMap<SrcScalarType, Shape>;
466 static RegisterBlockType Run(const SrcObjectType& src, int pos) {
467 static_assert(Shape == VectorShape::Col || Rows == 1, "");
468 static_assert(Shape == VectorShape::Row || Cols == 1, "");
469 RegisterBlockType result;
470 for (int i = 0; i < Rows * Cols; i++) {
471 result.buf.reg[i] = src(pos + i);
472 }
473 return result;
474 }
475 };
476
477 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
478 VectorShape Shape>
479 struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
480 VectorDup<SrcScalarType, Shape>> {
481 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
482 using SrcObjectType = VectorDup<SrcScalarType, Shape>;
483 static RegisterBlockType Run(const SrcObjectType& src, int) {
484 static_assert(Shape == VectorShape::Col || Rows == 1, "");
485 static_assert(Shape == VectorShape::Row || Cols == 1, "");
486 RegisterBlockType result;
487 for (int i = 0; i < Rows * Cols; i++) {
488 result.buf.reg[i] = src(0);
489 }
490 return result;
491 }
492 };
493
494 template <typename RegisterBlockType, typename SrcObjectType>
495 RegisterBlockType Load(const SrcObjectType& src, int row, int col) {
496 return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, row, col);
497 }
498
499 template <typename RegisterBlockType, typename SrcObjectType>
500 RegisterBlockType Load(const SrcObjectType& src, int pos) {
501 return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, pos);
502 }
503
// Generic (one lane per register) load of kScalarCount consecutive scalars
// from src into a register block, in order.
template <typename RegisterBlockType>
struct LoadContiguousImpl {
  using ScalarType = typename RegisterBlockType::ScalarType;
  static_assert(RegisterBlockType::kRegisterLanes == 1,
                "This path is only for scalar values");
  static RegisterBlockType Run(const ScalarType* src) {
    RegisterBlockType block;
    const ScalarType* cursor = src;
    for (int n = 0; n < RegisterBlockType::kScalarCount; ++n) {
      block.buf.reg[n] = *cursor++;
    }
    return block;
  }
};
517
518 template <typename RegisterBlockType>
519 RegisterBlockType LoadContiguous(
520 const typename RegisterBlockType::ScalarType* src) {
521 return LoadContiguousImpl<RegisterBlockType>::Run(src);
522 }
523
// Shape of the block loaded by LoadForBroadcasting for a given destination
// shape and source kind. The primary template is intentionally empty (no
// kRows/kCols); only the specializations below provide a shape.
template <int BroadcastRows, int BroadcastCols, typename SrcObjectType>
struct LoadForBroadcastingShape {
};
526
527 template <int BroadcastRows, int BroadcastCols, typename ScalarType,
528 VectorShape Shape>
529 struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols,
530 VectorMap<ScalarType, Shape>> {
531 static constexpr int kRows = Shape == VectorShape::Col ? BroadcastRows : 1;
532 static constexpr int kCols = Shape == VectorShape::Row ? BroadcastCols : 1;
533 };
534
535 template <int BroadcastRows, int BroadcastCols, typename ScalarType,
536 VectorShape Shape>
537 struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols,
538 VectorDup<ScalarType, Shape>> {
539 static constexpr int kRows = 1;
540 static constexpr int kCols = 1;
541 };
542
543 template <typename RegisterBlockType, typename SrcObjectType>
544 struct LoadForBroadcastingRegisterBlock {
545 using Shape =
546 LoadForBroadcastingShape<RegisterBlockType::kRows,
547 RegisterBlockType::kCols, SrcObjectType>;
548 using ScalarType = typename RegisterBlockType::ScalarType;
549 using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>;
550 };
551
// Primary template for LoadForBroadcasting. Only the partial specializations
// are meant to be used. Instantiating this one triggers the static_assert
// below (for any non-void source type).
template <typename RegisterBlockType, typename SrcObjectType>
struct LoadForBroadcastingImpl {
  static_assert(std::is_same<void, SrcObjectType>::value,
                "This generic impl should never be hit");
};
557
558 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
559 VectorShape Shape>
560 struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>,
561 VectorMap<SrcScalarType, Shape>> {
562 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
563 using SrcObjectType = VectorMap<SrcScalarType, Shape>;
564 using ResultBlockType =
565 typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
566 SrcObjectType>::Type;
567 static_assert(ResultBlockType::kRegisterLanes == 1,
568 "This path is only for scalar values");
569 static ResultBlockType Run(const SrcObjectType& src, int pos) {
570 ResultBlockType result;
571 for (int c = 0; c < ResultBlockType::kCols; c++) {
572 for (int r = 0; r < ResultBlockType::kRows; r++) {
573 const int i = Shape == VectorShape::Col ? r : c;
574 result.buf.reg[r + c * ResultBlockType::kRows] = src(pos + i);
575 }
576 }
577 return result;
578 }
579 };
580
581 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
582 VectorShape Shape>
583 struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>,
584 VectorDup<SrcScalarType, Shape>> {
585 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
586 using SrcObjectType = VectorDup<SrcScalarType, Shape>;
587 using ResultBlockType =
588 typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
589 SrcObjectType>::Type;
590 static_assert(ResultBlockType::kRegisterLanes == 1,
591 "This path is only for scalar values");
592 static ResultBlockType Run(const SrcObjectType& src, int) {
593 ResultBlockType result;
594 for (int c = 0; c < ResultBlockType::kCols; c++) {
595 for (int r = 0; r < ResultBlockType::kRows; r++) {
596 result.buf.reg[r + c * ResultBlockType::kRows] = src(0);
597 }
598 }
599 return result;
600 }
601 };
602
603 template <typename RegisterBlockType, typename SrcObjectType>
604 typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
605 SrcObjectType>::Type
606 LoadForBroadcasting(const SrcObjectType& src, int row, int col) {
607 return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(
608 src, row, col);
609 }
610
611 template <typename RegisterBlockType, typename SrcObjectType>
612 typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
613 SrcObjectType>::Type
614 LoadForBroadcasting(const SrcObjectType& src, int pos) {
615 return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(src,
616 pos);
617 }
618
619 template <int ConstantValue, typename RegisterBlockType>
620 struct AddConstantImpl {
621 static void Run(RegisterBlockType* block) {
622 using RegisterType = typename RegisterBlockType::RegisterType;
623 const RegisterType dup = Dup<RegisterType>(ConstantValue);
624 for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) {
625 block->buf.reg[i] = Add(block->buf.reg[i], dup);
626 }
627 }
628 };
629
630 template <typename RegisterBlockType>
631 struct AddConstantImpl<0, RegisterBlockType> {
632 static void Run(RegisterBlockType*) {
633 // This is a no-op.
634 }
635 };
636
637 template <int ConstantValue, typename RegisterBlockType>
638 void AddConstant(RegisterBlockType* block) {
639 AddConstantImpl<ConstantValue, RegisterBlockType>::Run(block);
640 }
641
642 template <int N>
643 using RegBufferInt32 = RegisterBuffer<std::int32_t, N>;
644 template <int N>
645 using RegBufferInt16 = RegisterBuffer<std::int16_t, N>;
646 template <int N>
647 using RegBufferUint8 = RegisterBuffer<std::uint8_t, N>;
648 template <int N>
649 using RegBufferInt8 = RegisterBuffer<std::int8_t, N>;
650 template <int R, int C>
651 using RegBlockInt32 = RegisterBlock<std::int32_t, R, C>;
652 template <int R, int C>
653 using RegBlockInt16 = RegisterBlock<std::int16_t, R, C>;
654 template <int R, int C>
655 using RegBlockUint8 = RegisterBlock<std::uint8_t, R, C>;
656 template <int R, int C>
657 using RegBlockInt8 = RegisterBlock<std::int8_t, R, C>;
658
659 } // end namespace gemmlowp
660
661 #if defined GEMMLOWP_NEON
662 #include "simd_wrappers_neon.h"
663 #elif defined GEMMLOWP_SSE4
664 #include "simd_wrappers_sse.h"
665 #elif defined GEMMLOWP_MSA
666 #include "simd_wrappers_msa.h"
667 #endif
668
669 #endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
670