xref: /aosp_15_r20/external/gemmlowp/internal/simd_wrappers_common_neon_sse.h (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1 // Copyright 2015 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // simd_wrappers_common_neon_sse.h: common SIMD (NEON and SSE) wrapper code
16 
17 #ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_
18 #define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_
19 
20 #include "simd_wrappers.h"
21 
22 namespace gemmlowp {
23 
24 template <typename SrcScalarType, int N>
25 struct LoadImpl<RegBlockInt32<4, N>,
26                 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
27   static RegBlockInt32<4, N> Run(
28       const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
29       int col) {
30     RegBlockInt32<4, N> result;
31     for (int i = 0; i < N; i++) {
32       result.buf.reg[i] = LoadInt32x4(src.data(row, col + i));
33     }
34     return result;
35   }
36 };
37 
38 template <typename SrcScalarType, int N>
39 struct LoadImpl<RegBlockInt32<8, N>,
40                 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
41   static RegBlockInt32<8, N> Run(
42       const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
43       int col) {
44     RegBlockInt32<8, N> result;
45     for (int i = 0; i < N; i++) {
46       result.buf.reg[2 * i + 0] = LoadInt32x4(src.data(row + 0, col + i));
47       result.buf.reg[2 * i + 1] = LoadInt32x4(src.data(row + 4, col + i));
48     }
49     return result;
50   }
51 };
52 
53 template <typename SrcScalarType>
54 struct LoadImpl<RegBlockInt32<1, 4>,
55                 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
56   static RegBlockInt32<1, 4> Run(
57       const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
58       int col) {
59     RegBlockInt32<1, 4> result;
60     std::int32_t buf[4];
61     for (int i = 0; i < 4; i++) {
62       buf[i] = src(row, col + i);
63     }
64     result.buf.reg[0] = LoadInt32x4(buf);
65     return result;
66   }
67 };
68 
69 template <typename SrcScalarType>
70 struct LoadImpl<RegBlockInt32<1, 8>,
71                 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
72   static RegBlockInt32<1, 8> Run(
73       const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
74       int col) {
75     RegBlockInt32<1, 8> result;
76     std::int32_t buf[8];
77     for (int i = 0; i < 8; i++) {
78       buf[i] = src(row, col + i);
79     }
80     result.buf.reg[0] = LoadInt32x4(buf);
81     result.buf.reg[1] = LoadInt32x4(buf + 4);
82     return result;
83   }
84 };
85 
86 template <typename SrcScalarType>
87 struct LoadImpl<RegBlockInt32<4, 1>,
88                 VectorMap<SrcScalarType, VectorShape::Col>> {
89   static RegBlockInt32<4, 1> Run(
90       const VectorMap<SrcScalarType, VectorShape::Col>& src, int pos) {
91     RegBlockInt32<4, 1> result;
92     result.buf.reg[0] = LoadInt32x4(src.data(pos));
93     return result;
94   }
95 };
96 
97 template <typename SrcScalarType>
98 struct LoadImpl<RegBlockInt32<4, 1>,
99                 VectorDup<SrcScalarType, VectorShape::Col>> {
100   static RegBlockInt32<4, 1> Run(
101       const VectorDup<SrcScalarType, VectorShape::Col>& src, int) {
102     RegBlockInt32<4, 1> result;
103     result.buf.reg[0] = LoadInt32x4(src(0));
104     return result;
105   }
106 };
107 
108 template <typename SrcScalarType, int N>
109 struct LoadForBroadcastingImpl<RegBlockInt32<4, N>,
110                                VectorMap<SrcScalarType, VectorShape::Col>> {
111   using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>;
112   using RegisterBlockType = RegBlockInt32<4, N>;
113   using ResultBlockType =
114       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
115                                                 SrcObjectType>::Type;
116 
117   static ResultBlockType Run(const SrcObjectType& src, int pos) {
118     ResultBlockType result;
119     static_assert(ResultBlockType::kRegisterCount == 1, "");
120     result.buf.reg[0] = LoadInt32x4(src.data(pos));
121     return result;
122   }
123 };
124 
125 template <typename SrcScalarType, int N>
126 struct LoadForBroadcastingImpl<RegBlockInt32<8, N>,
127                                VectorMap<SrcScalarType, VectorShape::Col>> {
128   using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>;
129   using RegisterBlockType = RegBlockInt32<8, N>;
130   using ResultBlockType =
131       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
132                                                 SrcObjectType>::Type;
133 
134   static ResultBlockType Run(const SrcObjectType& src, int pos) {
135     ResultBlockType result;
136     static_assert(ResultBlockType::kRegisterCount == 2, "");
137     result.buf.reg[0] = LoadInt32x4(src.data(pos));
138     result.buf.reg[1] = LoadInt32x4(src.data(pos + 4));
139     return result;
140   }
141 };
142 
143 template <typename SrcScalarType>
144 struct LoadForBroadcastingImpl<RegBlockInt32<4, 1>,
145                                VectorMap<SrcScalarType, VectorShape::Row>> {
146   using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>;
147   using RegisterBlockType = RegBlockInt32<4, 1>;
148   using ResultBlockType =
149       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
150                                                 SrcObjectType>::Type;
151 
152   static ResultBlockType Run(const SrcObjectType& src, int pos) {
153     ResultBlockType result;
154     result.buf.reg[0] = src(pos);
155     return result;
156   }
157 };
158 
159 template <typename SrcScalarType, int N>
160 struct LoadForBroadcastingImpl<RegBlockInt32<N, 4>,
161                                VectorMap<SrcScalarType, VectorShape::Row>> {
162   using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>;
163   using RegisterBlockType = RegBlockInt32<N, 4>;
164   using ResultBlockType =
165       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
166                                                 SrcObjectType>::Type;
167 
168   static ResultBlockType Run(const SrcObjectType& src, int pos) {
169     ResultBlockType result;
170     static_assert(ResultBlockType::kRegisterCount == 1, "");
171     result.buf.reg[0] = LoadInt32x4(src.data(pos));
172     return result;
173   }
174 };
175 
176 template <typename SrcScalarType, int N>
177 struct LoadForBroadcastingImpl<RegBlockInt32<N, 8>,
178                                VectorMap<SrcScalarType, VectorShape::Row>> {
179   using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>;
180   using RegisterBlockType = RegBlockInt32<N, 8>;
181   using ResultBlockType =
182       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
183                                                 SrcObjectType>::Type;
184 
185   static ResultBlockType Run(const SrcObjectType& src, int pos) {
186     ResultBlockType result;
187     static_assert(ResultBlockType::kRegisterCount == 2, "");
188     result.buf.reg[0] = LoadInt32x4(src.data(pos));
189     result.buf.reg[1] = LoadInt32x4(src.data(pos + 4));
190     return result;
191   }
192 };
193 
194 // 4x1 := 4x1 + 1x1
195 template <>
196 struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {
197   static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
198                                  const RegBlockInt32<1, 1>& rhs) {
199     RegBlockInt32<4, 1> result;
200     result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
201     return result;
202   }
203 };
204 
205 // 1x4 := 1x4 + 1x1
206 template <>
207 struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> {
208   static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
209                                  const RegBlockInt32<1, 1>& rhs) {
210     RegBlockInt32<1, 4> result;
211     result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
212     return result;
213   }
214 };
215 
216 // 4x1 := 4x1 + 4x1
217 template <>
218 struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> {
219   static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
220                                  const RegBlockInt32<4, 1>& rhs) {
221     RegBlockInt32<4, 1> result;
222     result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
223     return result;
224   }
225 };
226 
227 // 1x4 := 1x4 + 1x4
228 template <>
229 struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> {
230   static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
231                                  const RegBlockInt32<1, 4>& rhs) {
232     RegBlockInt32<1, 4> result;
233     result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
234     return result;
235   }
236 };
237 
238 // 4x4 := 4x4 + 1x4
239 template <>
240 struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> {
241   static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
242                                  const RegBlockInt32<1, 4>& rhs) {
243     RegBlockInt32<4, 4> result;
244     result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
245     result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
246     result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
247     result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
248     return result;
249   }
250 };
251 
252 // 4x4 := 4x4 + 4x1
253 template <>
254 struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> {
255   static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
256                                  const RegBlockInt32<4, 1>& rhs) {
257     RegBlockInt32<4, 4> result;
258     result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
259     result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[0]);
260     result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]);
261     result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[0]);
262     return result;
263   }
264 };
265 
266 // 8x1 := 8x1 + 1x1
267 template <>
268 struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> {
269   static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
270                                  const RegBlockInt32<1, 1>& rhs) {
271     RegBlockInt32<8, 1> result;
272     const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
273     for (int i = 0; i < 2; i++) {
274       result.buf.reg[i] = Add(lhs.buf.reg[i], p);
275     }
276     return result;
277   }
278 };
279 
280 // 8x1 := 8x1 + 8x1
281 template <>
282 struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> {
283   static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
284                                  const RegBlockInt32<8, 1>& rhs) {
285     RegBlockInt32<8, 1> result;
286     for (int i = 0; i < 2; i++) {
287       result.buf.reg[i] = Add(lhs.buf.reg[i], rhs.buf.reg[i]);
288     }
289     return result;
290   }
291 };
292 
293 // 8x4 := 8x4 + 1x4
294 template <>
295 struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> {
296   static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
297                                  const RegBlockInt32<1, 4>& rhs) {
298     RegBlockInt32<8, 4> result;
299     result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
300     result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
301     result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
302     result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
303     result.buf.reg[4] = Add(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
304     result.buf.reg[5] = Add(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
305     result.buf.reg[6] = Add(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
306     result.buf.reg[7] = Add(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
307     return result;
308   }
309 };
310 
311 // 8x4 := 8x4 + 8x1
312 template <>
313 struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> {
314   static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
315                                  const RegBlockInt32<8, 1>& rhs) {
316     RegBlockInt32<8, 4> result;
317     result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
318     result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]);
319     result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]);
320     result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[1]);
321     result.buf.reg[4] = Add(lhs.buf.reg[4], rhs.buf.reg[0]);
322     result.buf.reg[5] = Add(lhs.buf.reg[5], rhs.buf.reg[1]);
323     result.buf.reg[6] = Add(lhs.buf.reg[6], rhs.buf.reg[0]);
324     result.buf.reg[7] = Add(lhs.buf.reg[7], rhs.buf.reg[1]);
325     return result;
326   }
327 };
328 
329 // 1x8 := 1x8 + 1x8
330 template <>
331 struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 8>> {
332   static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
333                                  const RegBlockInt32<1, 8>& rhs) {
334     RegBlockInt32<1, 8> result;
335     result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
336     result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]);
337     return result;
338   }
339 };
340 
341 // 1x8 := 1x8 + 1x1
342 template <>
343 struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> {
344   static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
345                                  const RegBlockInt32<1, 1>& rhs) {
346     RegBlockInt32<1, 8> result;
347     result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
348     result.buf.reg[1] = Add(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
349     return result;
350   }
351 };
352 
353 // 4x1 := 4x1 + 1x1
354 template <>
355 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 1>,
356                                                       RegBlockInt32<1, 1>> {
357   static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
358                                  const RegBlockInt32<1, 1>& rhs) {
359     RegBlockInt32<4, 1> result;
360     result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
361         lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
362     return result;
363   }
364 };
365 
366 // 1x4 := 1x4 + 1x1
367 template <>
368 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 4>,
369                                                       RegBlockInt32<1, 1>> {
370   static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
371                                  const RegBlockInt32<1, 1>& rhs) {
372     RegBlockInt32<1, 4> result;
373     result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
374         lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
375     return result;
376   }
377 };
378 
379 // 4x1 := 4x1 + 4x1
380 template <>
381 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 1>,
382                                                       RegBlockInt32<4, 1>> {
383   static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
384                                  const RegBlockInt32<4, 1>& rhs) {
385     RegBlockInt32<4, 1> result;
386     result.buf.reg[0] =
387         SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
388     return result;
389   }
390 };
391 
392 // 1x4 := 1x4 + 1x4
393 template <>
394 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 4>,
395                                                       RegBlockInt32<1, 4>> {
396   static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
397                                  const RegBlockInt32<1, 4>& rhs) {
398     RegBlockInt32<1, 4> result;
399     result.buf.reg[0] =
400         SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
401     return result;
402   }
403 };
404 
405 // 4x4 := 4x4 + 1x4
406 template <>
407 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 4>,
408                                                       RegBlockInt32<1, 4>> {
409   static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
410                                  const RegBlockInt32<1, 4>& rhs) {
411     RegBlockInt32<4, 4> result;
412     result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
413         lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
414     result.buf.reg[1] = SaturatingRoundingDoublingHighMul(
415         lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
416     result.buf.reg[2] = SaturatingRoundingDoublingHighMul(
417         lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
418     result.buf.reg[3] = SaturatingRoundingDoublingHighMul(
419         lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
420     return result;
421   }
422 };
423 
424 // 4x4 := 4x4 + 4x1
425 template <>
426 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 4>,
427                                                       RegBlockInt32<4, 1>> {
428   static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
429                                  const RegBlockInt32<4, 1>& rhs) {
430     RegBlockInt32<4, 4> result;
431     result.buf.reg[0] =
432         SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
433     result.buf.reg[1] =
434         SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[0]);
435     result.buf.reg[2] =
436         SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]);
437     result.buf.reg[3] =
438         SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[0]);
439     return result;
440   }
441 };
442 
443 // 8x1 := 8x1 + 1x1
444 template <>
445 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 1>,
446                                                       RegBlockInt32<1, 1>> {
447   static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
448                                  const RegBlockInt32<1, 1>& rhs) {
449     RegBlockInt32<8, 1> result;
450     const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
451     for (int i = 0; i < 2; i++) {
452       result.buf.reg[i] = SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], p);
453     }
454     return result;
455   }
456 };
457 
458 // 8x1 := 8x1 + 8x1
459 template <>
460 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 1>,
461                                                       RegBlockInt32<8, 1>> {
462   static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
463                                  const RegBlockInt32<8, 1>& rhs) {
464     RegBlockInt32<8, 1> result;
465     for (int i = 0; i < 2; i++) {
466       result.buf.reg[i] =
467           SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], rhs.buf.reg[i]);
468     }
469     return result;
470   }
471 };
472 
473 // 8x4 := 8x4 + 1x4
474 template <>
475 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 4>,
476                                                       RegBlockInt32<1, 4>> {
477   static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
478                                  const RegBlockInt32<1, 4>& rhs) {
479     RegBlockInt32<8, 4> result;
480     result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
481         lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
482     result.buf.reg[1] = SaturatingRoundingDoublingHighMul(
483         lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
484     result.buf.reg[2] = SaturatingRoundingDoublingHighMul(
485         lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
486     result.buf.reg[3] = SaturatingRoundingDoublingHighMul(
487         lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
488     result.buf.reg[4] = SaturatingRoundingDoublingHighMul(
489         lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
490     result.buf.reg[5] = SaturatingRoundingDoublingHighMul(
491         lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
492     result.buf.reg[6] = SaturatingRoundingDoublingHighMul(
493         lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
494     result.buf.reg[7] = SaturatingRoundingDoublingHighMul(
495         lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
496     return result;
497   }
498 };
499 
500 // 8x4 := 8x4 + 8x1
501 template <>
502 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 4>,
503                                                       RegBlockInt32<8, 1>> {
504   static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
505                                  const RegBlockInt32<8, 1>& rhs) {
506     RegBlockInt32<8, 4> result;
507     result.buf.reg[0] =
508         SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
509     result.buf.reg[1] =
510         SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]);
511     result.buf.reg[2] =
512         SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]);
513     result.buf.reg[3] =
514         SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[1]);
515     result.buf.reg[4] =
516         SaturatingRoundingDoublingHighMul(lhs.buf.reg[4], rhs.buf.reg[0]);
517     result.buf.reg[5] =
518         SaturatingRoundingDoublingHighMul(lhs.buf.reg[5], rhs.buf.reg[1]);
519     result.buf.reg[6] =
520         SaturatingRoundingDoublingHighMul(lhs.buf.reg[6], rhs.buf.reg[0]);
521     result.buf.reg[7] =
522         SaturatingRoundingDoublingHighMul(lhs.buf.reg[7], rhs.buf.reg[1]);
523     return result;
524   }
525 };
526 
527 // 1x8 := 1x8 + 1x8
528 template <>
529 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 8>,
530                                                       RegBlockInt32<1, 8>> {
531   static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
532                                  const RegBlockInt32<1, 8>& rhs) {
533     RegBlockInt32<1, 8> result;
534     result.buf.reg[0] =
535         SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
536     result.buf.reg[1] =
537         SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]);
538     return result;
539   }
540 };
541 
542 // 1x8 := 1x8 + 1x1
543 template <>
544 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 8>,
545                                                       RegBlockInt32<1, 1>> {
546   static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
547                                  const RegBlockInt32<1, 1>& rhs) {
548     RegBlockInt32<1, 8> result;
549     result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
550         lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
551     result.buf.reg[1] = SaturatingRoundingDoublingHighMul(
552         lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
553     return result;
554   }
555 };
556 
557 // 4x1 := 4x1 * 1x1
558 template <>
559 struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {
560   static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
561                                  const RegBlockInt32<1, 1>& rhs) {
562     RegBlockInt32<4, 1> result;
563     result.buf.reg[0] = Mul(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
564     return result;
565   }
566 };
567 
568 // 4x1 := 4x1 * 4x1
569 template <>
570 struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> {
571   static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
572                                  const RegBlockInt32<4, 1>& rhs) {
573     RegBlockInt32<4, 1> result;
574     result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
575     return result;
576   }
577 };
578 
579 // 1x4 := 1x4 * 1x4
580 template <>
581 struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> {
582   static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
583                                  const RegBlockInt32<1, 4>& rhs) {
584     RegBlockInt32<1, 4> result;
585     result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
586     return result;
587   }
588 };
589 
590 // 1x4 := 1x4 * 1x1
591 template <>
592 struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> {
593   static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
594                                  const RegBlockInt32<1, 1>& rhs) {
595     RegBlockInt32<1, 4> result;
596     result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
597     return result;
598   }
599 };
600 
601 // 4x4 := 4x4 * 1x4
602 template <>
603 struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> {
604   static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
605                                  const RegBlockInt32<1, 4>& rhs) {
606     RegBlockInt32<4, 4> result;
607     const Int32x4 p = rhs.buf.reg[0];
608     result.buf.reg[0] = MulByRhsLane<0>(lhs.buf.reg[0], p);
609     result.buf.reg[1] = MulByRhsLane<1>(lhs.buf.reg[1], p);
610     result.buf.reg[2] = MulByRhsLane<2>(lhs.buf.reg[2], p);
611     result.buf.reg[3] = MulByRhsLane<3>(lhs.buf.reg[3], p);
612     return result;
613   }
614 };
615 
616 // 4x4 := 4x4 * 4x1
617 template <>
618 struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> {
619   static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
620                                  const RegBlockInt32<4, 1>& rhs) {
621     RegBlockInt32<4, 4> result;
622     const Int32x4 p = rhs.buf.reg[0];
623     result.buf.reg[0] = Mul(lhs.buf.reg[0], p);
624     result.buf.reg[1] = Mul(lhs.buf.reg[1], p);
625     result.buf.reg[2] = Mul(lhs.buf.reg[2], p);
626     result.buf.reg[3] = Mul(lhs.buf.reg[3], p);
627     return result;
628   }
629 };
630 
631 // 8x1 := 8x1 * 1x1
632 template <>
633 struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> {
634   static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
635                                  const RegBlockInt32<1, 1>& rhs) {
636     RegBlockInt32<8, 1> result;
637     const std::int32_t p = rhs.buf.reg[0];
638     for (int i = 0; i < 2; i++) {
639       result.buf.reg[i] = Mul(lhs.buf.reg[i], p);
640     }
641     return result;
642   }
643 };
644 
645 // 8x1 := 8x1 * 8x1
646 template <>
647 struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> {
648   static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
649                                  const RegBlockInt32<8, 1>& rhs) {
650     RegBlockInt32<8, 1> result;
651     for (int i = 0; i < 2; i++) {
652       result.buf.reg[i] = Mul(lhs.buf.reg[i], rhs.buf.reg[i]);
653     }
654     return result;
655   }
656 };
657 
658 // 8x4 := 8x4 * 1x4
659 template <>
660 struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> {
661   static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
662                                  const RegBlockInt32<1, 4>& rhs) {
663     RegBlockInt32<8, 4> result;
664     const Int32x4 p = rhs.buf.reg[0];
665     for (int i = 0; i < 2; i++) {
666       result.buf.reg[i + 0] = MulByRhsLane<0>(lhs.buf.reg[i + 0], p);
667       result.buf.reg[i + 2] = MulByRhsLane<1>(lhs.buf.reg[i + 2], p);
668       result.buf.reg[i + 4] = MulByRhsLane<2>(lhs.buf.reg[i + 4], p);
669       result.buf.reg[i + 6] = MulByRhsLane<3>(lhs.buf.reg[i + 6], p);
670     }
671     return result;
672   }
673 };
674 
675 // 8x4 := 8x4 * 8x1
676 template <>
677 struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> {
678   static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
679                                  const RegBlockInt32<8, 1>& rhs) {
680     RegBlockInt32<8, 4> result;
681     const Int32x4 p[2]{rhs.buf.reg[0], rhs.buf.reg[1]};
682     for (int i = 0; i < 4; i++) {
683       for (int j = 0; j < 2; j++) {
684         const int k = j + 2 * i;
685         result.buf.reg[k] = Mul(lhs.buf.reg[k], p[j]);
686       }
687     }
688     return result;
689   }
690 };
691 
692 // Rx1 += Rx1 * 1x1
693 template <int Rows>
694 struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>,
695                            RegBlockInt32<Rows, 1>> {
696   static void Run(const RegBlockInt32<Rows, 1>& lhs,
697                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 1>* acc) {
698     const std::int32_t p = rhs.buf.reg[0];
699     for (int i = 0; i < RegBlockInt32<Rows, 1>::kRegisterCount; i++) {
700       MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]);
701     }
702   }
703 };
704 
705 // RxC += Rx1 * 1x1
706 template <int Rows, int Cols>
707 struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>,
708                            RegBlockInt32<Rows, Cols>> {
709   static void Run(const RegBlockInt32<Rows, 1>& lhs,
710                   const RegBlockInt32<1, 1>& rhs,
711                   RegBlockInt32<Rows, Cols>* acc) {
712     const std::int32_t p = rhs.buf.reg[0];
713     static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount;
714     for (int i = 0; i < kRegsPerCol; i++) {
715       const Int32x4 q = Mul(lhs.buf.reg[i], p);
716       for (int j = 0; j < Cols; j++) {
717         acc->buf.reg[i + j * kRegsPerCol] =
718             Add(acc->buf.reg[i + j * kRegsPerCol], q);
719       }
720     }
721   }
722 };
723 
724 // 1xC += 1xC * 1x1
725 template <int Cols>
726 struct BroadcastMulAddImpl<RegBlockInt32<1, Cols>, RegBlockInt32<1, 1>,
727                            RegBlockInt32<1, Cols>> {
728   static void Run(const RegBlockInt32<1, Cols>& lhs,
729                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) {
730     const std::int32_t p = rhs.buf.reg[0];
731     for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) {
732       MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]);
733     }
734   }
735 };
736 
737 // RxC += 1x1 * 1x1
738 template <int Rows, int Cols>
739 struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>,
740                            RegBlockInt32<Rows, Cols>> {
741   static void Run(const RegBlockInt32<1, 1>& lhs,
742                   const RegBlockInt32<1, 1>& rhs,
743                   RegBlockInt32<Rows, Cols>* acc) {
744     const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0]));
745     for (int i = 0; i < RegBlockInt32<Rows, Cols>::kRegisterCount; i++) {
746       acc->buf.reg[i] = Add(acc->buf.reg[i], p);
747     }
748   }
749 };
750 
751 // 1x1 += 1x1 * 1x1
752 template <>
753 struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>,
754                            RegBlockInt32<1, 1>> {
755   static void Run(const RegBlockInt32<1, 1>& lhs,
756                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 1>* acc) {
757     MulAdd(lhs.buf.reg[0], rhs.buf.reg[0], &acc->buf.reg[0]);
758   }
759 };
760 
761 // Rx4 += Rx1 * 1x4
762 template <int Rows>
763 struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 4>,
764                            RegBlockInt32<Rows, 4>> {
765   static void Run(const RegBlockInt32<Rows, 1>& lhs,
766                   const RegBlockInt32<1, 4>& rhs, RegBlockInt32<Rows, 4>* acc) {
767     const Int32x4 p = rhs.buf.reg[0];
768     static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount;
769     for (int i = 0; i < kRegsPerCol; i++) {
770       MulAddByRhsLane<0>(lhs.buf.reg[i], p, &acc->buf.reg[i + 0 * kRegsPerCol]);
771       MulAddByRhsLane<1>(lhs.buf.reg[i], p, &acc->buf.reg[i + 1 * kRegsPerCol]);
772       MulAddByRhsLane<2>(lhs.buf.reg[i], p, &acc->buf.reg[i + 2 * kRegsPerCol]);
773       MulAddByRhsLane<3>(lhs.buf.reg[i], p, &acc->buf.reg[i + 3 * kRegsPerCol]);
774     }
775   }
776 };
777 
778 // Rx4 += 1x4 * 1x1
779 template <int Rows>
780 struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>,
781                            RegBlockInt32<Rows, 4>> {
782   static void Run(const RegBlockInt32<1, 4>& lhs,
783                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 4>* acc) {
784     const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
785     Int32x4 q[4];
786     q[0] = DupLane<0>(p);
787     q[1] = DupLane<1>(p);
788     q[2] = DupLane<2>(p);
789     q[3] = DupLane<3>(p);
790     static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount;
791     for (int i = 0; i < kRegsPerCol; i++) {
792       for (int j = 0; j < 4; j++) {
793         acc->buf.reg[i + j * kRegsPerCol] =
794             Add(q[j], acc->buf.reg[i + j * kRegsPerCol]);
795       }
796     }
797   }
798 };
799 
800 // 1xC += 1x1 * 1x1
801 template <int Cols>
802 struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>,
803                            RegBlockInt32<1, Cols>> {
804   static void Run(const RegBlockInt32<1, 1>& lhs,
805                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) {
806     const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0]));
807     for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) {
808       acc->buf.reg[i] = Add(acc->buf.reg[i], p);
809     }
810   }
811 };
812 
813 // 1x4 += 1x4 * 1x1
814 template <>
815 struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>,
816                            RegBlockInt32<1, 4>> {
817   static void Run(const RegBlockInt32<1, 4>& lhs,
818                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 4>* acc) {
819     const std::int32_t p = rhs.buf.reg[0];
820     MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]);
821   }
822 };
823 
824 // 4xC += 4x1 * 1x1
825 template <int Cols>
826 struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>,
827                            RegBlockInt32<4, Cols>> {
828   static void Run(const RegBlockInt32<4, 1>& lhs,
829                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, Cols>* acc) {
830     const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
831     for (int i = 0; i < Cols; i++) {
832       acc->buf.reg[i] = Add(p, acc->buf.reg[i]);
833     }
834   }
835 };
836 
837 // 4x1 += 4x1 * 1x1
838 template <>
839 struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>,
840                            RegBlockInt32<4, 1>> {
841   static void Run(const RegBlockInt32<4, 1>& lhs,
842                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, 1>* acc) {
843     const std::int32_t p = rhs.buf.reg[0];
844     MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]);
845   }
846 };
847 
848 }  // namespace gemmlowp
849 
850 #endif  // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_
851