1*5f39d1b3SJooyung Han // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han // http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han
15*5f39d1b3SJooyung Han // pack_neon.h: optimized NEON specializations of the templates in pack.h.
16*5f39d1b3SJooyung Han
17*5f39d1b3SJooyung Han #ifndef GEMMLOWP_INTERNAL_PACK_NEON_H_
18*5f39d1b3SJooyung Han #define GEMMLOWP_INTERNAL_PACK_NEON_H_
19*5f39d1b3SJooyung Han
20*5f39d1b3SJooyung Han #include "pack.h"
21*5f39d1b3SJooyung Han
22*5f39d1b3SJooyung Han #include <arm_neon.h>
23*5f39d1b3SJooyung Han
24*5f39d1b3SJooyung Han namespace gemmlowp {
25*5f39d1b3SJooyung Han
26*5f39d1b3SJooyung Han typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
27*5f39d1b3SJooyung Han WidthMajorUint8SideMap;
28*5f39d1b3SJooyung Han
29*5f39d1b3SJooyung Han typedef SideMap<const std::int8_t, SideMapOrder::WidthMajor>
30*5f39d1b3SJooyung Han WidthMajorInt8SideMap;
31*5f39d1b3SJooyung Han
32*5f39d1b3SJooyung Han template <int Cells>
33*5f39d1b3SJooyung Han using DepthMajorSideFormatNCells4x2 = KernelSideFormat<CellFormat<4, 2>, Cells>;
34*5f39d1b3SJooyung Han
35*5f39d1b3SJooyung Han template <int Cells>
36*5f39d1b3SJooyung Han class PackingRegisterBlock<
37*5f39d1b3SJooyung Han WidthMajorUint8SideMap,
38*5f39d1b3SJooyung Han PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells>>>
39*5f39d1b3SJooyung Han : public PackingRegisterBlockBase<
40*5f39d1b3SJooyung Han WidthMajorUint8SideMap,
41*5f39d1b3SJooyung Han PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells>>> {
42*5f39d1b3SJooyung Han public:
43*5f39d1b3SJooyung Han typedef DepthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
44*5f39d1b3SJooyung Han typedef typename KernelSideFormat::Cell CellFormat;
45*5f39d1b3SJooyung Han static const int kCells = KernelSideFormat::kCells;
46*5f39d1b3SJooyung Han static const int kCellWidth = CellFormat::kWidth;
47*5f39d1b3SJooyung Han static const int kKernelWidth = CellFormat::kWidth * kCells;
48*5f39d1b3SJooyung Han static const int kCellDepth = CellFormat::kDepth;
49*5f39d1b3SJooyung Han static const int kCellSize = CellFormat::kSize;
50*5f39d1b3SJooyung Han
Pack(PackedSideBlock<KernelSideFormat> * dst,int start_width)51*5f39d1b3SJooyung Han void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
52*5f39d1b3SJooyung Han std::uint8_t* dst_ptr = dst->current_data();
53*5f39d1b3SJooyung Han const std::uint8_t* const src_ptr = this->complete_src_.data();
54*5f39d1b3SJooyung Han const int stride = this->complete_src_.stride();
55*5f39d1b3SJooyung Han // Load source WidthMajor data
56*5f39d1b3SJooyung Han uint8x16_t src_lines[4 * kCells];
57*5f39d1b3SJooyung Han for (int i = 0; i < 4 * kCells; i++) {
58*5f39d1b3SJooyung Han src_lines[i] = vld1q_u8(src_ptr + i * stride);
59*5f39d1b3SJooyung Han }
60*5f39d1b3SJooyung Han // Reorder the data within registers to make DepthMajor 4x2 cells
61*5f39d1b3SJooyung Han uint8x16x2_t src_lines_intertwined_2x[2 * kCells];
62*5f39d1b3SJooyung Han for (int i = 0; i < kCells; i++) {
63*5f39d1b3SJooyung Han src_lines_intertwined_2x[2 * i] =
64*5f39d1b3SJooyung Han vzipq_u8(src_lines[4 * i], src_lines[4 * i + 2]);
65*5f39d1b3SJooyung Han src_lines_intertwined_2x[2 * i + 1] =
66*5f39d1b3SJooyung Han vzipq_u8(src_lines[4 * i + 1], src_lines[4 * i + 3]);
67*5f39d1b3SJooyung Han }
68*5f39d1b3SJooyung Han uint8x16x2_t src_lines_intertwined_4x[2 * kCells];
69*5f39d1b3SJooyung Han for (int i = 0; i < kCells; i++) {
70*5f39d1b3SJooyung Han src_lines_intertwined_4x[2 * i] =
71*5f39d1b3SJooyung Han vzipq_u8(src_lines_intertwined_2x[2 * i].val[0],
72*5f39d1b3SJooyung Han src_lines_intertwined_2x[2 * i + 1].val[0]);
73*5f39d1b3SJooyung Han src_lines_intertwined_4x[2 * i + 1] =
74*5f39d1b3SJooyung Han vzipq_u8(src_lines_intertwined_2x[2 * i].val[1],
75*5f39d1b3SJooyung Han src_lines_intertwined_2x[2 * i + 1].val[1]);
76*5f39d1b3SJooyung Han }
77*5f39d1b3SJooyung Han // Store the resulting DepthMajor 4x2 cells in the destination packed block
78*5f39d1b3SJooyung Han for (int outer = 0; outer < 2; outer++) {
79*5f39d1b3SJooyung Han for (int inner = 0; inner < 2; inner++) {
80*5f39d1b3SJooyung Han for (int cell = 0; cell < kCells; cell++) {
81*5f39d1b3SJooyung Han uint8x8_t value = vget_low_u8(
82*5f39d1b3SJooyung Han src_lines_intertwined_4x[2 * cell + outer].val[inner]);
83*5f39d1b3SJooyung Han vst1_u8(dst_ptr, value);
84*5f39d1b3SJooyung Han dst_ptr += 8;
85*5f39d1b3SJooyung Han }
86*5f39d1b3SJooyung Han for (int cell = 0; cell < kCells; cell++) {
87*5f39d1b3SJooyung Han uint8x8_t value = vget_high_u8(
88*5f39d1b3SJooyung Han src_lines_intertwined_4x[2 * cell + outer].val[inner]);
89*5f39d1b3SJooyung Han vst1_u8(dst_ptr, value);
90*5f39d1b3SJooyung Han dst_ptr += 8;
91*5f39d1b3SJooyung Han }
92*5f39d1b3SJooyung Han }
93*5f39d1b3SJooyung Han }
94*5f39d1b3SJooyung Han // Compute sums across the depth dimension
95*5f39d1b3SJooyung Han uint16x8_t sums_of_2_cells[kCells][4];
96*5f39d1b3SJooyung Han for (int outer = 0; outer < 2; outer++) {
97*5f39d1b3SJooyung Han for (int inner = 0; inner < 2; inner++) {
98*5f39d1b3SJooyung Han int i = 2 * outer + inner;
99*5f39d1b3SJooyung Han for (int cell = 0; cell < kCells; cell++) {
100*5f39d1b3SJooyung Han sums_of_2_cells[cell][i] = vaddl_u8(
101*5f39d1b3SJooyung Han vget_low_u8(
102*5f39d1b3SJooyung Han src_lines_intertwined_4x[2 * cell + outer].val[inner]),
103*5f39d1b3SJooyung Han vget_high_u8(
104*5f39d1b3SJooyung Han src_lines_intertwined_4x[2 * cell + outer].val[inner]));
105*5f39d1b3SJooyung Han }
106*5f39d1b3SJooyung Han }
107*5f39d1b3SJooyung Han }
108*5f39d1b3SJooyung Han int32x4_t sums_of_4_cells[kCells][4];
109*5f39d1b3SJooyung Han for (int i = 0; i < 4; i++) {
110*5f39d1b3SJooyung Han for (int cell = 0; cell < kCells; cell++) {
111*5f39d1b3SJooyung Han sums_of_4_cells[cell][i] = vreinterpretq_s32_u32(
112*5f39d1b3SJooyung Han vaddl_u16(vget_low_u16(sums_of_2_cells[cell][i]),
113*5f39d1b3SJooyung Han vget_high_u16(sums_of_2_cells[cell][i])));
114*5f39d1b3SJooyung Han }
115*5f39d1b3SJooyung Han }
116*5f39d1b3SJooyung Han // Update the sums_of_each_slice vector
117*5f39d1b3SJooyung Han for (int cell = 0; cell < kCells; cell++) {
118*5f39d1b3SJooyung Han int32x4_t s01 =
119*5f39d1b3SJooyung Han vaddq_s32(sums_of_4_cells[cell][0], sums_of_4_cells[cell][1]);
120*5f39d1b3SJooyung Han int32x4_t s23 =
121*5f39d1b3SJooyung Han vaddq_s32(sums_of_4_cells[cell][2], sums_of_4_cells[cell][3]);
122*5f39d1b3SJooyung Han int32x4_t s = vaddq_s32(s01, s23);
123*5f39d1b3SJooyung Han std::int32_t* sums_of_each_slice_ptr =
124*5f39d1b3SJooyung Han dst->sums_of_each_slice() + start_width + 4 * cell;
125*5f39d1b3SJooyung Han vst1q_s32(sums_of_each_slice_ptr,
126*5f39d1b3SJooyung Han vaddq_s32(s, vld1q_s32(sums_of_each_slice_ptr)));
127*5f39d1b3SJooyung Han }
128*5f39d1b3SJooyung Han dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
129*5f39d1b3SJooyung Han }
130*5f39d1b3SJooyung Han };
131*5f39d1b3SJooyung Han
132*5f39d1b3SJooyung Han template <int Cells>
133*5f39d1b3SJooyung Han using WidthMajorSideFormatNCells4x2 =
134*5f39d1b3SJooyung Han KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;
135*5f39d1b3SJooyung Han
136*5f39d1b3SJooyung Han template <int Cells>
137*5f39d1b3SJooyung Han class PackingRegisterBlock<
138*5f39d1b3SJooyung Han WidthMajorUint8SideMap,
139*5f39d1b3SJooyung Han PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>>
140*5f39d1b3SJooyung Han : public PackingRegisterBlockBase<
141*5f39d1b3SJooyung Han WidthMajorUint8SideMap,
142*5f39d1b3SJooyung Han PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> {
143*5f39d1b3SJooyung Han public:
144*5f39d1b3SJooyung Han typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
145*5f39d1b3SJooyung Han typedef typename KernelSideFormat::Cell CellFormat;
146*5f39d1b3SJooyung Han static const int kCells = KernelSideFormat::kCells;
147*5f39d1b3SJooyung Han static const int kCellWidth = CellFormat::kWidth;
148*5f39d1b3SJooyung Han static const int kKernelWidth = CellFormat::kWidth * kCells;
149*5f39d1b3SJooyung Han static const int kCellDepth = CellFormat::kDepth;
150*5f39d1b3SJooyung Han static const int kCellSize = CellFormat::kSize;
151*5f39d1b3SJooyung Han
Pack(PackedSideBlock<KernelSideFormat> * dst,int start_width)152*5f39d1b3SJooyung Han void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
153*5f39d1b3SJooyung Han std::uint8_t* dst_ptr = dst->current_data();
154*5f39d1b3SJooyung Han const std::uint8_t* src_ptr = this->complete_src_.data();
155*5f39d1b3SJooyung Han const int stride = this->complete_src_.stride();
156*5f39d1b3SJooyung Han // Load source WidthMajor data
157*5f39d1b3SJooyung Han uint16x8_t src_lines[kCells * 4];
158*5f39d1b3SJooyung Han for (int i = 0; i < kCells; i++) {
159*5f39d1b3SJooyung Han // This packing path is used with our current
160*5f39d1b3SJooyung Han // less-than-8-bit kernel, and the partial unrolling of this loop
161*5f39d1b3SJooyung Han // results in substantially faster code (thanks to better
162*5f39d1b3SJooyung Han // register allocation) on Nexus 5.
163*5f39d1b3SJooyung Han
164*5f39d1b3SJooyung Han #define GEMMLOWP_UNROLLED_LOOP_ITER(k) \
165*5f39d1b3SJooyung Han src_lines[4 * i + k] = vreinterpretq_u16_u8(vld1q_u8(src_ptr)); \
166*5f39d1b3SJooyung Han src_ptr += stride;
167*5f39d1b3SJooyung Han
168*5f39d1b3SJooyung Han GEMMLOWP_UNROLLED_LOOP_ITER(0)
169*5f39d1b3SJooyung Han GEMMLOWP_UNROLLED_LOOP_ITER(1)
170*5f39d1b3SJooyung Han GEMMLOWP_UNROLLED_LOOP_ITER(2)
171*5f39d1b3SJooyung Han GEMMLOWP_UNROLLED_LOOP_ITER(3)
172*5f39d1b3SJooyung Han
173*5f39d1b3SJooyung Han #undef GEMMLOWP_UNROLLED_LOOP_ITER
174*5f39d1b3SJooyung Han }
175*5f39d1b3SJooyung Han // Reorder the data within registers to make WidthMajor 4x2 cells
176*5f39d1b3SJooyung Han uint16x8x2_t src_lines_intertwined_2x[2 * kCells];
177*5f39d1b3SJooyung Han for (int i = 0; i < kCells; i++) {
178*5f39d1b3SJooyung Han src_lines_intertwined_2x[2 * i] =
179*5f39d1b3SJooyung Han vzipq_u16(src_lines[4 * i], src_lines[4 * i + 2]);
180*5f39d1b3SJooyung Han src_lines_intertwined_2x[2 * i + 1] =
181*5f39d1b3SJooyung Han vzipq_u16(src_lines[4 * i + 1], src_lines[4 * i + 3]);
182*5f39d1b3SJooyung Han }
183*5f39d1b3SJooyung Han uint16x8x2_t src_lines_intertwined_4x[2 * kCells];
184*5f39d1b3SJooyung Han for (int i = 0; i < kCells; i++) {
185*5f39d1b3SJooyung Han src_lines_intertwined_4x[2 * i] =
186*5f39d1b3SJooyung Han vzipq_u16(src_lines_intertwined_2x[2 * i].val[0],
187*5f39d1b3SJooyung Han src_lines_intertwined_2x[2 * i + 1].val[0]);
188*5f39d1b3SJooyung Han src_lines_intertwined_4x[2 * i + 1] =
189*5f39d1b3SJooyung Han vzipq_u16(src_lines_intertwined_2x[2 * i].val[1],
190*5f39d1b3SJooyung Han src_lines_intertwined_2x[2 * i + 1].val[1]);
191*5f39d1b3SJooyung Han }
192*5f39d1b3SJooyung Han // Store the resulting WidthMajor 4x2 cells in the destination packed block
193*5f39d1b3SJooyung Han for (int outer = 0; outer < 2; outer++) {
194*5f39d1b3SJooyung Han for (int inner = 0; inner < 2; inner++) {
195*5f39d1b3SJooyung Han for (int cell = 0; cell < kCells; cell++) {
196*5f39d1b3SJooyung Han uint8x8_t value = vreinterpret_u8_u16(vget_low_u16(
197*5f39d1b3SJooyung Han src_lines_intertwined_4x[2 * cell + outer].val[inner]));
198*5f39d1b3SJooyung Han vst1_u8(dst_ptr, value);
199*5f39d1b3SJooyung Han dst_ptr += 8;
200*5f39d1b3SJooyung Han }
201*5f39d1b3SJooyung Han for (int cell = 0; cell < kCells; cell++) {
202*5f39d1b3SJooyung Han uint8x8_t value = vreinterpret_u8_u16(vget_high_u16(
203*5f39d1b3SJooyung Han src_lines_intertwined_4x[2 * cell + outer].val[inner]));
204*5f39d1b3SJooyung Han vst1_u8(dst_ptr, value);
205*5f39d1b3SJooyung Han dst_ptr += 8;
206*5f39d1b3SJooyung Han }
207*5f39d1b3SJooyung Han }
208*5f39d1b3SJooyung Han }
209*5f39d1b3SJooyung Han // Compute sums across the depth dimension
210*5f39d1b3SJooyung Han uint16x8_t sums_of_2[kCells][4];
211*5f39d1b3SJooyung Han for (int outer = 0; outer < 2; outer++) {
212*5f39d1b3SJooyung Han for (int inner = 0; inner < 2; inner++) {
213*5f39d1b3SJooyung Han int i = 2 * outer + inner;
214*5f39d1b3SJooyung Han for (int cell = 0; cell < kCells; cell++) {
215*5f39d1b3SJooyung Han sums_of_2[cell][i] = vpaddlq_u8(vreinterpretq_u8_u16(
216*5f39d1b3SJooyung Han src_lines_intertwined_4x[2 * cell + outer].val[inner]));
217*5f39d1b3SJooyung Han }
218*5f39d1b3SJooyung Han }
219*5f39d1b3SJooyung Han }
220*5f39d1b3SJooyung Han uint16x8_t sums_of_4[kCells][2];
221*5f39d1b3SJooyung Han for (int i = 0; i < 2; i++) {
222*5f39d1b3SJooyung Han for (int cell = 0; cell < kCells; cell++) {
223*5f39d1b3SJooyung Han sums_of_4[cell][i] =
224*5f39d1b3SJooyung Han vaddq_u16(sums_of_2[cell][2 * i], sums_of_2[cell][2 * i + 1]);
225*5f39d1b3SJooyung Han }
226*5f39d1b3SJooyung Han }
227*5f39d1b3SJooyung Han uint16x8_t sums_of_8[kCells];
228*5f39d1b3SJooyung Han for (int cell = 0; cell < kCells; cell++) {
229*5f39d1b3SJooyung Han sums_of_8[cell] = vaddq_u16(sums_of_4[cell][0], sums_of_4[cell][1]);
230*5f39d1b3SJooyung Han }
231*5f39d1b3SJooyung Han
232*5f39d1b3SJooyung Han uint16x4_t sums_of_16[kCells];
233*5f39d1b3SJooyung Han for (int cell = 0; cell < kCells; cell++) {
234*5f39d1b3SJooyung Han sums_of_16[cell] = vadd_u16(vget_low_u16(sums_of_8[cell]),
235*5f39d1b3SJooyung Han vget_high_u16(sums_of_8[cell]));
236*5f39d1b3SJooyung Han }
237*5f39d1b3SJooyung Han // Update the sums_of_each_slice vector
238*5f39d1b3SJooyung Han for (int cell = 0; cell < kCells; cell++) {
239*5f39d1b3SJooyung Han int32x4_t s = vreinterpretq_s32_u32(vmovl_u16(sums_of_16[cell]));
240*5f39d1b3SJooyung Han std::int32_t* sums_of_each_slice_ptr =
241*5f39d1b3SJooyung Han dst->sums_of_each_slice() + start_width + 4 * cell;
242*5f39d1b3SJooyung Han vst1q_s32(sums_of_each_slice_ptr,
243*5f39d1b3SJooyung Han vaddq_s32(s, vld1q_s32(sums_of_each_slice_ptr)));
244*5f39d1b3SJooyung Han }
245*5f39d1b3SJooyung Han dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
246*5f39d1b3SJooyung Han }
247*5f39d1b3SJooyung Han };
248*5f39d1b3SJooyung Han
249*5f39d1b3SJooyung Han #ifdef GEMMLOWP_NEON_32
vpaddq_s16(int16x8_t a,int16x8_t b)250*5f39d1b3SJooyung Han inline int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
251*5f39d1b3SJooyung Han const int16x4_t c = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
252*5f39d1b3SJooyung Han const int16x4_t d = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
253*5f39d1b3SJooyung Han return vcombine_s16(c, d);
254*5f39d1b3SJooyung Han }
255*5f39d1b3SJooyung Han #endif
256*5f39d1b3SJooyung Han
257*5f39d1b3SJooyung Han template <int Width>
258*5f39d1b3SJooyung Han using Int8FastKernelFormat =
259*5f39d1b3SJooyung Han KernelSideFormatInt8<CellFormat<Width, 16, CellOrder::WidthMajor>, 1>;
260*5f39d1b3SJooyung Han
261*5f39d1b3SJooyung Han template <int Width>
262*5f39d1b3SJooyung Han class PackingRegisterBlock<WidthMajorUint8SideMap,
263*5f39d1b3SJooyung Han PackedSideBlock<Int8FastKernelFormat<Width>>>
264*5f39d1b3SJooyung Han : public PackingRegisterBlockBase<
265*5f39d1b3SJooyung Han WidthMajorUint8SideMap,
266*5f39d1b3SJooyung Han PackedSideBlock<Int8FastKernelFormat<Width>>> {
267*5f39d1b3SJooyung Han public:
268*5f39d1b3SJooyung Han static_assert(Width == 2 || Width == 4, "");
269*5f39d1b3SJooyung Han typedef Int8FastKernelFormat<Width> KernelSideFormat;
270*5f39d1b3SJooyung Han typedef typename KernelSideFormat::Cell CellFormat;
271*5f39d1b3SJooyung Han static const int kCells = KernelSideFormat::kCells;
272*5f39d1b3SJooyung Han static const int kCellWidth = CellFormat::kWidth;
273*5f39d1b3SJooyung Han static const int kKernelWidth = CellFormat::kWidth * kCells;
274*5f39d1b3SJooyung Han static const int kCellDepth = CellFormat::kDepth;
275*5f39d1b3SJooyung Han static const int kCellSize = CellFormat::kSize;
276*5f39d1b3SJooyung Han
Pack(PackedSideBlock<KernelSideFormat> * dst,int start_width)277*5f39d1b3SJooyung Han void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
278*5f39d1b3SJooyung Han std::int32_t* sums_ptr = dst->sums_of_each_slice() + start_width;
279*5f39d1b3SJooyung Han std::uint8_t* dst_ptr = dst->current_data();
280*5f39d1b3SJooyung Han const std::uint8_t* const src_ptr = this->complete_src_.data();
281*5f39d1b3SJooyung Han const int stride = this->complete_src_.stride();
282*5f39d1b3SJooyung Han // Load source WidthMajor data
283*5f39d1b3SJooyung Han uint8x16_t src_lines[Width];
284*5f39d1b3SJooyung Han for (int i = 0; i < Width; i++) {
285*5f39d1b3SJooyung Han src_lines[i] = vld1q_u8(src_ptr + i * stride);
286*5f39d1b3SJooyung Han }
287*5f39d1b3SJooyung Han const uint8x16_t sign_bit_dup = vdupq_n_u8(0x80);
288*5f39d1b3SJooyung Han for (int i = 0; i < Width; i++) {
289*5f39d1b3SJooyung Han src_lines[i] = veorq_u8(src_lines[i], sign_bit_dup);
290*5f39d1b3SJooyung Han }
291*5f39d1b3SJooyung Han for (int i = 0; i < Width; i++) {
292*5f39d1b3SJooyung Han vst1q_u8(dst_ptr + 16 * i, src_lines[i]);
293*5f39d1b3SJooyung Han }
294*5f39d1b3SJooyung Han int16x8_t sums2[Width];
295*5f39d1b3SJooyung Han for (int i = 0; i < Width; i++) {
296*5f39d1b3SJooyung Han const int8x8_t lo = vreinterpret_s8_u8(vget_low_u8(src_lines[i]));
297*5f39d1b3SJooyung Han const int8x8_t hi = vreinterpret_s8_u8(vget_high_u8(src_lines[i]));
298*5f39d1b3SJooyung Han sums2[i] = vaddl_s8(lo, hi);
299*5f39d1b3SJooyung Han }
300*5f39d1b3SJooyung Han int16x8_t sums4[Width / 2];
301*5f39d1b3SJooyung Han for (int i = 0; i < Width / 2; i++) {
302*5f39d1b3SJooyung Han sums4[i] = vpaddq_s16(sums2[2 * i], sums2[2 * i + 1]);
303*5f39d1b3SJooyung Han }
304*5f39d1b3SJooyung Han if (Width == 4) {
305*5f39d1b3SJooyung Han int32x4_t sum = vld1q_s32(sums_ptr);
306*5f39d1b3SJooyung Han int16x8_t sums8 = vpaddq_s16(sums4[0], sums4[1]);
307*5f39d1b3SJooyung Han sum = vpadalq_s16(sum, sums8);
308*5f39d1b3SJooyung Han vst1q_s32(sums_ptr, sum);
309*5f39d1b3SJooyung Han } else {
310*5f39d1b3SJooyung Han assert(Width == 2);
311*5f39d1b3SJooyung Han int32x2_t sum = vld1_s32(sums_ptr);
312*5f39d1b3SJooyung Han int16x4_t sums8 =
313*5f39d1b3SJooyung Han vpadd_s16(vget_low_s16(sums4[0]), vget_high_s16(sums4[0]));
314*5f39d1b3SJooyung Han sum = vpadal_s16(sum, sums8);
315*5f39d1b3SJooyung Han vst1_s32(sums_ptr, sum);
316*5f39d1b3SJooyung Han }
317*5f39d1b3SJooyung Han dst->seek_forward_n_cells(1);
318*5f39d1b3SJooyung Han }
319*5f39d1b3SJooyung Han };
320*5f39d1b3SJooyung Han
321*5f39d1b3SJooyung Han template <int Width>
322*5f39d1b3SJooyung Han using Int8InputsFastKernelFormat =
323*5f39d1b3SJooyung Han KernelSideFormatInt8Inputs<CellFormat<Width, 16, CellOrder::WidthMajor>, 1>;
324*5f39d1b3SJooyung Han
325*5f39d1b3SJooyung Han // Same as above, but for int8 inputs, avoiding the uint8 -> int8 conversion.
326*5f39d1b3SJooyung Han template <int Width>
327*5f39d1b3SJooyung Han class PackingRegisterBlock<WidthMajorInt8SideMap,
328*5f39d1b3SJooyung Han PackedSideBlock<Int8InputsFastKernelFormat<Width>>>
329*5f39d1b3SJooyung Han : public PackingRegisterBlockBase<
330*5f39d1b3SJooyung Han WidthMajorInt8SideMap,
331*5f39d1b3SJooyung Han PackedSideBlock<Int8InputsFastKernelFormat<Width>>> {
332*5f39d1b3SJooyung Han public:
333*5f39d1b3SJooyung Han static_assert(Width == 2 || Width == 4, "");
334*5f39d1b3SJooyung Han typedef Int8InputsFastKernelFormat<Width> KernelSideFormat;
335*5f39d1b3SJooyung Han typedef typename KernelSideFormat::Cell CellFormat;
336*5f39d1b3SJooyung Han static const int kCells = KernelSideFormat::kCells;
337*5f39d1b3SJooyung Han static const int kCellWidth = CellFormat::kWidth;
338*5f39d1b3SJooyung Han static const int kKernelWidth = CellFormat::kWidth * kCells;
339*5f39d1b3SJooyung Han static const int kCellDepth = CellFormat::kDepth;
340*5f39d1b3SJooyung Han static const int kCellSize = CellFormat::kSize;
341*5f39d1b3SJooyung Han
Pack(PackedSideBlock<KernelSideFormat> * dst,int start_width)342*5f39d1b3SJooyung Han void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
343*5f39d1b3SJooyung Han std::int32_t* sums_ptr = dst->sums_of_each_slice() + start_width;
344*5f39d1b3SJooyung Han std::int8_t* dst_ptr = reinterpret_cast<std::int8_t*>(dst->current_data());
345*5f39d1b3SJooyung Han const std::int8_t* const src_ptr = this->complete_src_.data();
346*5f39d1b3SJooyung Han const int stride = this->complete_src_.stride();
347*5f39d1b3SJooyung Han // Load source WidthMajor data
348*5f39d1b3SJooyung Han int8x16_t src_lines[Width];
349*5f39d1b3SJooyung Han for (int i = 0; i < Width; i++) {
350*5f39d1b3SJooyung Han src_lines[i] = vld1q_s8(src_ptr + i * stride);
351*5f39d1b3SJooyung Han }
352*5f39d1b3SJooyung Han for (int i = 0; i < Width; i++) {
353*5f39d1b3SJooyung Han vst1q_s8(dst_ptr + 16 * i, src_lines[i]);
354*5f39d1b3SJooyung Han }
355*5f39d1b3SJooyung Han int16x8_t sums2[Width];
356*5f39d1b3SJooyung Han for (int i = 0; i < Width; i++) {
357*5f39d1b3SJooyung Han const int8x8_t lo = vget_low_s8(src_lines[i]);
358*5f39d1b3SJooyung Han const int8x8_t hi = vget_high_s8(src_lines[i]);
359*5f39d1b3SJooyung Han sums2[i] = vaddl_s8(lo, hi);
360*5f39d1b3SJooyung Han }
361*5f39d1b3SJooyung Han int16x8_t sums4[Width / 2];
362*5f39d1b3SJooyung Han for (int i = 0; i < Width / 2; i++) {
363*5f39d1b3SJooyung Han sums4[i] = vpaddq_s16(sums2[2 * i], sums2[2 * i + 1]);
364*5f39d1b3SJooyung Han }
365*5f39d1b3SJooyung Han if (Width == 4) {
366*5f39d1b3SJooyung Han int32x4_t sum = vld1q_s32(sums_ptr);
367*5f39d1b3SJooyung Han int16x8_t sums8 = vpaddq_s16(sums4[0], sums4[1]);
368*5f39d1b3SJooyung Han sum = vpadalq_s16(sum, sums8);
369*5f39d1b3SJooyung Han vst1q_s32(sums_ptr, sum);
370*5f39d1b3SJooyung Han } else {
371*5f39d1b3SJooyung Han assert(Width == 2);
372*5f39d1b3SJooyung Han int32x2_t sum = vld1_s32(sums_ptr);
373*5f39d1b3SJooyung Han int16x4_t sums8 =
374*5f39d1b3SJooyung Han vpadd_s16(vget_low_s16(sums4[0]), vget_high_s16(sums4[0]));
375*5f39d1b3SJooyung Han sum = vpadal_s16(sum, sums8);
376*5f39d1b3SJooyung Han vst1_s32(sums_ptr, sum);
377*5f39d1b3SJooyung Han }
378*5f39d1b3SJooyung Han dst->seek_forward_n_cells(1);
379*5f39d1b3SJooyung Han }
380*5f39d1b3SJooyung Han };
381*5f39d1b3SJooyung Han
382*5f39d1b3SJooyung Han } // namespace gemmlowp
383*5f39d1b3SJooyung Han
384*5f39d1b3SJooyung Han #endif // GEMMLOWP_INTERNAL_PACK_NEON_H_
385