// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <immintrin.h>

#include <xnnpack/common.h>
#include <xnnpack/dwconv.h>
#include <xnnpack/math.h>
#include <xnnpack/unaligned.h>
#include <xnnpack/vcvt.h>
#include <xnnpack/vlrelu.h>

void xnn_f32_dwconv2d_chw_ukernel_3x3p1__ssse3_2x4_acc2(
    size_t input_height,
    size_t input_width,
    const float* input,
    const float* weights,
    const float* zero,
    float* output,
    uint32_t padding_top,
    const union xnn_f32_chw_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(input_height != 0);
  assert(input_width != 0);
  assert(input_width % sizeof(float) == 0);
  assert(padding_top == 1);

  const __m128 vmask = _mm_load_ps((const float*) params->sse.mask);
  const __m128 vmax = _mm_load_ps(params->sse.max);
  const __m128 vmin = _mm_load_ps(params->sse.min);

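  // Packed weights: the bias first, then the nine 3x3 filter taps in row-major
  // order (k00, k01, k02, k10, ..., k22); each value is broadcast across a vector.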
  const __m128 vbias = _mm_load1_ps(weights);
  const __m128 vk00 = _mm_load1_ps(weights + 1);
  const __m128 vk01 = _mm_load1_ps(weights + 2);
  const __m128 vk02 = _mm_load1_ps(weights + 3);
  const __m128 vk10 = _mm_load1_ps(weights + 4);
  const __m128 vk11 = _mm_load1_ps(weights + 5);
  const __m128 vk12 = _mm_load1_ps(weights + 6);
  const __m128 vk20 = _mm_load1_ps(weights + 7);
  const __m128 vk21 = _mm_load1_ps(weights + 8);
  const __m128 vk22 = _mm_load1_ps(weights + 9);

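  // Bytes by which each row pointer advances past its row start per pass
  // (input_width rounded up to the 4-float tile); used at the bottom of the
  // loop to rewind i2/i3 into the top input rows of the next output-row pair.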
  const size_t input_decrement = round_up_po2(input_width, 4 * sizeof(float));

  const float* i0 = zero;
  const float* i1 = input;
  const float* i2 = (const float*) ((uintptr_t) i1 + input_width);
  const float* i3 = (const float*) ((uintptr_t) i2 + input_width);

  float* o0 = output;
  float* o1 = (float*) ((uintptr_t) o0 + input_width);

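  // Each pass produces two output rows (o0, o1) from four input rows (i0..i3);
  // the zero row stands in for rows outside the padded input.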
  size_t output_height = input_height;
  do {
    if XNN_UNPREDICTABLE(output_height < 2) {
      i2 = zero;
      o1 = o0;
    }
    if XNN_UNPREDICTABLE(output_height < 3) {
      i3 = zero;
    }

    __m128 vi0x0123 = _mm_setzero_ps();
    __m128 vi1x0123 = _mm_setzero_ps();
    __m128 vi2x0123 = _mm_setzero_ps();
    __m128 vi3x0123 = _mm_setzero_ps();

    __m128 vi0x4567 = _mm_loadu_ps(i0);
    i0 += 4;
    __m128 vi1x4567 = _mm_loadu_ps(i1);
    i1 += 4;
    __m128 vi2x4567 = _mm_loadu_ps(i2);
    i2 += 4;
    __m128 vi3x4567 = _mm_loadu_ps(i3);
    i3 += 4;

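    // Main loop: 4 output pixels per row per iteration. The 3x3 window is
    // assembled from the previous (x0123), current (x4567), and next (x89AB)
    // vectors using PALIGNR to form the shifted x3456 and x5678 views. Two
    // accumulators per row (the "acc2" in the kernel name) shorten the add
    // dependency chain and are summed before clamping.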
    size_t w = input_width;
    for (; w > 4 * sizeof(float); w -= 4 * sizeof(float)) {
      const __m128 vi0x89AB = _mm_loadu_ps(i0);
      i0 += 4;
      const __m128 vi1x89AB = _mm_loadu_ps(i1);
      i1 += 4;
      const __m128 vi2x89AB = _mm_loadu_ps(i2);
      i2 += 4;
      const __m128 vi3x89AB = _mm_loadu_ps(i3);
      i3 += 4;

      __m128 vo0p0 = _mm_add_ps(vbias, _mm_mul_ps(vi0x4567, vk01));
      __m128 vo1p0 = _mm_add_ps(vbias, _mm_mul_ps(vi1x4567, vk01));
      __m128 vo0p1 = _mm_mul_ps(vi1x4567, vk11);
      __m128 vo1p1 = _mm_mul_ps(vi2x4567, vk11);
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x4567, vk21));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x4567, vk21));

      const __m128 vi0x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi0x4567), _mm_castps_si128(vi0x0123), 12));
      const __m128 vi1x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi1x4567), _mm_castps_si128(vi1x0123), 12));
      const __m128 vi2x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi2x4567), _mm_castps_si128(vi2x0123), 12));
      const __m128 vi3x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi3x4567), _mm_castps_si128(vi3x0123), 12));

      vo0p1 = _mm_add_ps(vo0p1, _mm_mul_ps(vi0x3456, vk00));
      vo1p1 = _mm_add_ps(vo1p1, _mm_mul_ps(vi1x3456, vk00));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi1x3456, vk10));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi2x3456, vk10));
      vo0p1 = _mm_add_ps(vo0p1, _mm_mul_ps(vi2x3456, vk20));
      vo1p1 = _mm_add_ps(vo1p1, _mm_mul_ps(vi3x3456, vk20));

      vi0x0123 = vi0x4567;
      vi1x0123 = vi1x4567;
      vi2x0123 = vi2x4567;
      vi3x0123 = vi3x4567;

      const __m128 vi0x5678 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi0x89AB), _mm_castps_si128(vi0x4567), 4));
      const __m128 vi1x5678 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi1x89AB), _mm_castps_si128(vi1x4567), 4));
      const __m128 vi2x5678 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi2x89AB), _mm_castps_si128(vi2x4567), 4));
      const __m128 vi3x5678 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi3x89AB), _mm_castps_si128(vi3x4567), 4));

      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi0x5678, vk02));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi1x5678, vk02));
      vo0p1 = _mm_add_ps(vo0p1, _mm_mul_ps(vi1x5678, vk12));
      vo1p1 = _mm_add_ps(vo1p1, _mm_mul_ps(vi2x5678, vk12));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x5678, vk22));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x5678, vk22));

      vi0x4567 = vi0x89AB;
      vi1x4567 = vi1x89AB;
      vi2x4567 = vi2x89AB;
      vi3x4567 = vi3x89AB;

      vo0p0 = _mm_add_ps(vo0p0, vo0p1);
      vo1p0 = _mm_add_ps(vo1p0, vo1p1);

      __m128 vo0 = _mm_max_ps(vo0p0, vmin);
      __m128 vo1 = _mm_max_ps(vo1p0, vmin);

      vo0 = _mm_min_ps(vo0, vmax);
      vo1 = _mm_min_ps(vo1, vmax);

      _mm_storeu_ps(o1, vo1);
      o1 += 4;
      _mm_storeu_ps(o0, vo0);
      o0 += 4;
    }
    // Always process the last block of 1..4 pixels.
    assert(w >= 1 * sizeof(float));
    assert(w <= 4 * sizeof(float));
    {
      vi0x4567 = _mm_and_ps(vmask, vi0x4567);
      vi1x4567 = _mm_and_ps(vmask, vi1x4567);
      vi2x4567 = _mm_and_ps(vmask, vi2x4567);
      vi3x4567 = _mm_and_ps(vmask, vi3x4567);

      __m128 vo0p0 = _mm_add_ps(vbias, _mm_mul_ps(vi0x4567, vk01));
      __m128 vo1p0 = _mm_add_ps(vbias, _mm_mul_ps(vi1x4567, vk01));
      __m128 vo0p1 = _mm_mul_ps(vi1x4567, vk11);
      __m128 vo1p1 = _mm_mul_ps(vi2x4567, vk11);
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x4567, vk21));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x4567, vk21));

      const __m128 vi0x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi0x4567), _mm_castps_si128(vi0x0123), 12));
      const __m128 vi1x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi1x4567), _mm_castps_si128(vi1x0123), 12));
      const __m128 vi2x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi2x4567), _mm_castps_si128(vi2x0123), 12));
      const __m128 vi3x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi3x4567), _mm_castps_si128(vi3x0123), 12));

      vo0p1 = _mm_add_ps(vo0p1, _mm_mul_ps(vi0x3456, vk00));
      vo1p1 = _mm_add_ps(vo1p1, _mm_mul_ps(vi1x3456, vk00));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi1x3456, vk10));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi2x3456, vk10));
      vo0p1 = _mm_add_ps(vo0p1, _mm_mul_ps(vi2x3456, vk20));
      vo1p1 = _mm_add_ps(vo1p1, _mm_mul_ps(vi3x3456, vk20));

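      // Right edge: shift zeros in (instead of loading past the row) to build
      // the x5678 window for the final 1..4 pixels.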
      const __m128i vzero = _mm_setzero_si128();
      const __m128 vi0x5678 = _mm_castsi128_ps(_mm_alignr_epi8(vzero, _mm_castps_si128(vi0x4567), 4));
      const __m128 vi1x5678 = _mm_castsi128_ps(_mm_alignr_epi8(vzero, _mm_castps_si128(vi1x4567), 4));
      const __m128 vi2x5678 = _mm_castsi128_ps(_mm_alignr_epi8(vzero, _mm_castps_si128(vi2x4567), 4));
      const __m128 vi3x5678 = _mm_castsi128_ps(_mm_alignr_epi8(vzero, _mm_castps_si128(vi3x4567), 4));

      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi0x5678, vk02));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi1x5678, vk02));
      vo0p1 = _mm_add_ps(vo0p1, _mm_mul_ps(vi1x5678, vk12));
      vo1p1 = _mm_add_ps(vo1p1, _mm_mul_ps(vi2x5678, vk12));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x5678, vk22));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x5678, vk22));

      vo0p0 = _mm_add_ps(vo0p0, vo0p1);
      vo1p0 = _mm_add_ps(vo1p0, vo1p1);

      __m128 vo0 = _mm_max_ps(vo0p0, vmin);
      __m128 vo1 = _mm_max_ps(vo1p0, vmin);

      vo0 = _mm_min_ps(vo0, vmax);
      vo1 = _mm_min_ps(vo1, vmax);

      if XNN_LIKELY(w == 4 * sizeof(float)) {
        _mm_storeu_ps(o1, vo1);
        o1 += 4;
        _mm_storeu_ps(o0, vo0);
        o0 += 4;
      } else {
        if (w & (2 * sizeof(float))) {
          _mm_storel_pi((__m64*) o1, vo1);
          o1 += 2;
          _mm_storel_pi((__m64*) o0, vo0);
          o0 += 2;

          vo0 = _mm_movehl_ps(vo0, vo0);
          vo1 = _mm_movehl_ps(vo1, vo1);
        }
        if (w & (1 * sizeof(float))) {
          _mm_store_ss(o1, vo1);
          o1 += 1;
          _mm_store_ss(o0, vo0);
          o0 += 1;
        }
      }
    }

    i0 = (const float*) ((uintptr_t) i2 - input_decrement);
    i1 = (const float*) ((uintptr_t) i3 - input_decrement);
    i2 = (const float*) ((uintptr_t) i1 + input_width);
    i3 = (const float*) ((uintptr_t) i2 + input_width);

    o0 = o1;
    o1 = (float*) ((uintptr_t) o0 + input_width);

    output_height = doz(output_height, 2);
  } while (output_height != 0);
}

void xnn_qs8_vcvt_ukernel__ssse3_x32(
    size_t n,
    const int8_t* x,
    int8_t* y,
    const union xnn_qs8_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(n != 0);
  assert(n % sizeof(int8_t) == 0);
  assert(x != NULL);
  assert(y != NULL);

  const __m128i vinput_zero_point = _mm_load_si128((const __m128i*) params->ssse3.input_zero_point);
  const __m128i vmultiplier = _mm_load_si128((const __m128i*) params->ssse3.multiplier);
  const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->ssse3.output_zero_point);
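  // Per 16-byte vector: sign-extend to 16 bits, subtract from the packed input
  // zero point, pre-shift by 7 so that PMULHRSW applies the Q15 multiplier with
  // rounding, then add the output zero point with saturation and pack to int8.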
  for (; n >= 32 * sizeof(int8_t); n -= 32 * sizeof(int8_t)) {
    const __m128i vx0 = _mm_loadu_si128((const __m128i*) x);
    const __m128i vx1 = _mm_loadu_si128((const __m128i*) (x + 16));
    x += 32;

    const __m128i vm0 = _mm_cmpgt_epi8(_mm_setzero_si128(), vx0);
    __m128i vacc0 = _mm_unpacklo_epi8(vx0, vm0);
    __m128i vacc1 = _mm_unpackhi_epi8(vx0, vm0);
    const __m128i vm1 = _mm_cmpgt_epi8(_mm_setzero_si128(), vx1);
    __m128i vacc2 = _mm_unpacklo_epi8(vx1, vm1);
    __m128i vacc3 = _mm_unpackhi_epi8(vx1, vm1);

    vacc0 = _mm_sub_epi16(vinput_zero_point, vacc0);
    vacc1 = _mm_sub_epi16(vinput_zero_point, vacc1);
    vacc2 = _mm_sub_epi16(vinput_zero_point, vacc2);
    vacc3 = _mm_sub_epi16(vinput_zero_point, vacc3);

    vacc0 = _mm_slli_epi16(vacc0, 7);
    vacc1 = _mm_slli_epi16(vacc1, 7);
    vacc2 = _mm_slli_epi16(vacc2, 7);
    vacc3 = _mm_slli_epi16(vacc3, 7);

    vacc0 = _mm_mulhrs_epi16(vacc0, vmultiplier);
    vacc1 = _mm_mulhrs_epi16(vacc1, vmultiplier);
    vacc2 = _mm_mulhrs_epi16(vacc2, vmultiplier);
    vacc3 = _mm_mulhrs_epi16(vacc3, vmultiplier);

    vacc0 = _mm_adds_epi16(vacc0, voutput_zero_point);
    vacc1 = _mm_adds_epi16(vacc1, voutput_zero_point);
    vacc2 = _mm_adds_epi16(vacc2, voutput_zero_point);
    vacc3 = _mm_adds_epi16(vacc3, voutput_zero_point);

    const __m128i vy0 = _mm_packs_epi16(vacc0, vacc1);
    const __m128i vy1 = _mm_packs_epi16(vacc2, vacc3);

    _mm_storeu_si128((__m128i*) y, vy0);
    _mm_storeu_si128((__m128i*) (y + 16), vy1);
    y += 32;
  }
  for (; n >= 16 * sizeof(int8_t); n -= 16 * sizeof(int8_t)) {
    const __m128i vx = _mm_loadu_si128((const __m128i*) x);
    x += 16;

    const __m128i vm = _mm_cmpgt_epi8(_mm_setzero_si128(), vx);
    __m128i vacc_lo = _mm_unpacklo_epi8(vx, vm);
    __m128i vacc_hi = _mm_unpackhi_epi8(vx, vm);
    vacc_lo = _mm_sub_epi16(vinput_zero_point, vacc_lo);
    vacc_hi = _mm_sub_epi16(vinput_zero_point, vacc_hi);
    vacc_lo = _mm_slli_epi16(vacc_lo, 7);
    vacc_hi = _mm_slli_epi16(vacc_hi, 7);
    vacc_lo = _mm_mulhrs_epi16(vacc_lo, vmultiplier);
    vacc_hi = _mm_mulhrs_epi16(vacc_hi, vmultiplier);
    vacc_lo = _mm_adds_epi16(vacc_lo, voutput_zero_point);
    vacc_hi = _mm_adds_epi16(vacc_hi, voutput_zero_point);

    const __m128i vy = _mm_packs_epi16(vacc_lo, vacc_hi);
    _mm_storeu_si128((__m128i*) y, vy);
    y += 16;
  }
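  // Final 1..15 bytes: the full-vector load may read past the end of x (the
  // kernel is tagged XNN_OOB_READS), but only n bytes are stored below.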
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(int8_t));
    assert(n <= 15 * sizeof(int8_t));

    const __m128i vx = _mm_loadu_si128((const __m128i*) x);

    const __m128i vm = _mm_cmpgt_epi8(_mm_setzero_si128(), vx);
    __m128i vacc_lo = _mm_unpacklo_epi8(vx, vm);
    __m128i vacc_hi = _mm_unpackhi_epi8(vx, vm);
    vacc_lo = _mm_sub_epi16(vinput_zero_point, vacc_lo);
    vacc_hi = _mm_sub_epi16(vinput_zero_point, vacc_hi);
    vacc_lo = _mm_slli_epi16(vacc_lo, 7);
    vacc_hi = _mm_slli_epi16(vacc_hi, 7);
    vacc_lo = _mm_mulhrs_epi16(vacc_lo, vmultiplier);
    vacc_hi = _mm_mulhrs_epi16(vacc_hi, vmultiplier);
    vacc_lo = _mm_adds_epi16(vacc_lo, voutput_zero_point);
    vacc_hi = _mm_adds_epi16(vacc_hi, voutput_zero_point);

    __m128i vy = _mm_packs_epi16(vacc_lo, vacc_hi);
    if (n & (8 * sizeof(int8_t))) {
      _mm_storel_epi64((__m128i*) y, vy);
      vy = _mm_unpackhi_epi64(vy, vy);
      y += 8;
    }
    if (n & (4 * sizeof(int8_t))) {
      unaligned_store_u32(y, (uint32_t) _mm_cvtsi128_si32(vy));
      vy = _mm_srli_epi64(vy, 32);
      y += 4;
    }
    uint32_t vy_lo = (uint32_t) _mm_cvtsi128_si32(vy);
    if (n & (2 * sizeof(int8_t))) {
      unaligned_store_u16(y, (uint16_t) vy_lo);
      vy_lo >>= 16;
      y += 2;
    }
    if (n & (1 * sizeof(int8_t))) {
      *y = (int8_t) vy_lo;
    }
  }
}

void xnn_qs8_vlrelu_ukernel__ssse3_x32(
    size_t n,
    const int8_t* x,
    int8_t* y,
    const union xnn_qs8_lrelu_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(n != 0);
  assert(n % sizeof(int8_t) == 0);
  assert(x != NULL);
  assert(y != NULL);

  const __m128i vinput_zero_point = _mm_load_si128((const __m128i*) params->sse2.input_zero_point);
  const __m128i vmultiplier_diff = _mm_load_si128((const __m128i*) params->sse2.multiplier_diff);
  const __m128i vmultiplier_base = _mm_load_si128((const __m128i*) params->sse2.multiplier_base);
  const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->sse2.output_zero_point);
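  // Branch-free LeakyReLU multiplier select: the per-element compare against
  // the input zero point masks multiplier_diff, and the XOR with
  // multiplier_base picks one of the two slopes' multipliers, which PMULHRSW
  // then applies as in the convert kernel above.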
  for (; n >= 32 * sizeof(int8_t); n -= 32 * sizeof(int8_t)) {
    const __m128i vx0 = _mm_loadu_si128((const __m128i*) x);
    const __m128i vx1 = _mm_loadu_si128((const __m128i*) (x + 16));
    x += 32;

    const __m128i vm0 = _mm_cmpgt_epi8(_mm_setzero_si128(), vx0);
    __m128i vacc0 = _mm_unpacklo_epi8(vx0, vm0);
    __m128i vacc1 = _mm_unpackhi_epi8(vx0, vm0);
    const __m128i vm1 = _mm_cmpgt_epi8(_mm_setzero_si128(), vx1);
    __m128i vacc2 = _mm_unpacklo_epi8(vx1, vm1);
    __m128i vacc3 = _mm_unpackhi_epi8(vx1, vm1);

    __m128i vmultiplier0 = _mm_cmpgt_epi16(vacc0, vinput_zero_point);
    vacc0 = _mm_sub_epi16(vinput_zero_point, vacc0);
    __m128i vmultiplier1 = _mm_cmpgt_epi16(vacc1, vinput_zero_point);
    vacc1 = _mm_sub_epi16(vinput_zero_point, vacc1);
    __m128i vmultiplier2 = _mm_cmpgt_epi16(vacc2, vinput_zero_point);
    vacc2 = _mm_sub_epi16(vinput_zero_point, vacc2);
    __m128i vmultiplier3 = _mm_cmpgt_epi16(vacc3, vinput_zero_point);
    vacc3 = _mm_sub_epi16(vinput_zero_point, vacc3);

    vmultiplier0 = _mm_and_si128(vmultiplier0, vmultiplier_diff);
    vacc0 = _mm_slli_epi16(vacc0, 7);
    vmultiplier0 = _mm_xor_si128(vmultiplier0, vmultiplier_base);
    vmultiplier1 = _mm_and_si128(vmultiplier1, vmultiplier_diff);
    vacc1 = _mm_slli_epi16(vacc1, 7);
    vmultiplier1 = _mm_xor_si128(vmultiplier1, vmultiplier_base);
    vmultiplier2 = _mm_and_si128(vmultiplier2, vmultiplier_diff);
    vacc2 = _mm_slli_epi16(vacc2, 7);
    vmultiplier2 = _mm_xor_si128(vmultiplier2, vmultiplier_base);
    vmultiplier3 = _mm_and_si128(vmultiplier3, vmultiplier_diff);
    vacc3 = _mm_slli_epi16(vacc3, 7);
    vmultiplier3 = _mm_xor_si128(vmultiplier3, vmultiplier_base);

    vacc0 = _mm_mulhrs_epi16(vacc0, vmultiplier0);
    vacc1 = _mm_mulhrs_epi16(vacc1, vmultiplier1);
    vacc2 = _mm_mulhrs_epi16(vacc2, vmultiplier2);
    vacc3 = _mm_mulhrs_epi16(vacc3, vmultiplier3);

    vacc0 = _mm_adds_epi16(vacc0, voutput_zero_point);
    vacc1 = _mm_adds_epi16(vacc1, voutput_zero_point);
    vacc2 = _mm_adds_epi16(vacc2, voutput_zero_point);
    vacc3 = _mm_adds_epi16(vacc3, voutput_zero_point);

    const __m128i vy0 = _mm_packs_epi16(vacc0, vacc1);
    const __m128i vy1 = _mm_packs_epi16(vacc2, vacc3);

    _mm_storeu_si128((__m128i*) y, vy0);
    _mm_storeu_si128((__m128i*) (y + 16), vy1);
    y += 32;
  }
  for (; n >= 16 * sizeof(int8_t); n -= 16 * sizeof(int8_t)) {
    const __m128i vx = _mm_loadu_si128((const __m128i*) x);
    x += 16;

    const __m128i vm = _mm_cmpgt_epi8(_mm_setzero_si128(), vx);
    __m128i vacc_lo = _mm_unpacklo_epi8(vx, vm);
    __m128i vacc_hi = _mm_unpackhi_epi8(vx, vm);
    __m128i vmultiplier_lo = _mm_cmpgt_epi16(vacc_lo, vinput_zero_point);
    __m128i vmultiplier_hi = _mm_cmpgt_epi16(vacc_hi, vinput_zero_point);
    vacc_lo = _mm_sub_epi16(vinput_zero_point, vacc_lo);
    vacc_hi = _mm_sub_epi16(vinput_zero_point, vacc_hi);
    vmultiplier_lo = _mm_and_si128(vmultiplier_lo, vmultiplier_diff);
    vmultiplier_hi = _mm_and_si128(vmultiplier_hi, vmultiplier_diff);
    vacc_lo = _mm_slli_epi16(vacc_lo, 7);
    vacc_hi = _mm_slli_epi16(vacc_hi, 7);
    vmultiplier_lo = _mm_xor_si128(vmultiplier_lo, vmultiplier_base);
    vmultiplier_hi = _mm_xor_si128(vmultiplier_hi, vmultiplier_base);
    vacc_lo = _mm_mulhrs_epi16(vacc_lo, vmultiplier_lo);
    vacc_hi = _mm_mulhrs_epi16(vacc_hi, vmultiplier_hi);
    vacc_lo = _mm_adds_epi16(vacc_lo, voutput_zero_point);
    vacc_hi = _mm_adds_epi16(vacc_hi, voutput_zero_point);

    const __m128i vy = _mm_packs_epi16(vacc_lo, vacc_hi);
    _mm_storeu_si128((__m128i*) y, vy);
    y += 16;
  }
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(int8_t));
    assert(n <= 15 * sizeof(int8_t));

    const __m128i vx = _mm_loadu_si128((const __m128i*) x);

    const __m128i vm = _mm_cmpgt_epi8(_mm_setzero_si128(), vx);
    __m128i vacc_lo = _mm_unpacklo_epi8(vx, vm);
    __m128i vacc_hi = _mm_unpackhi_epi8(vx, vm);
    __m128i vmultiplier_lo = _mm_cmpgt_epi16(vacc_lo, vinput_zero_point);
    __m128i vmultiplier_hi = _mm_cmpgt_epi16(vacc_hi, vinput_zero_point);
    vacc_lo = _mm_sub_epi16(vinput_zero_point, vacc_lo);
    vacc_hi = _mm_sub_epi16(vinput_zero_point, vacc_hi);
    vmultiplier_lo = _mm_and_si128(vmultiplier_lo, vmultiplier_diff);
    vmultiplier_hi = _mm_and_si128(vmultiplier_hi, vmultiplier_diff);
    vacc_lo = _mm_slli_epi16(vacc_lo, 7);
    vacc_hi = _mm_slli_epi16(vacc_hi, 7);
    vmultiplier_lo = _mm_xor_si128(vmultiplier_lo, vmultiplier_base);
    vmultiplier_hi = _mm_xor_si128(vmultiplier_hi, vmultiplier_base);
    vacc_lo = _mm_mulhrs_epi16(vacc_lo, vmultiplier_lo);
    vacc_hi = _mm_mulhrs_epi16(vacc_hi, vmultiplier_hi);
    vacc_lo = _mm_adds_epi16(vacc_lo, voutput_zero_point);
    vacc_hi = _mm_adds_epi16(vacc_hi, voutput_zero_point);

    __m128i vy = _mm_packs_epi16(vacc_lo, vacc_hi);
    if (n & (8 * sizeof(int8_t))) {
      _mm_storel_epi64((__m128i*) y, vy);
      vy = _mm_unpackhi_epi64(vy, vy);
      y += 8;
    }
    if (n & (4 * sizeof(int8_t))) {
      unaligned_store_u32(y, (uint32_t) _mm_cvtsi128_si32(vy));
      vy = _mm_srli_epi64(vy, 32);
      y += 4;
    }
    uint32_t vy_lo = (uint32_t) _mm_cvtsi128_si32(vy);
    if (n & (2 * sizeof(int8_t))) {
      unaligned_store_u16(y, (uint16_t) vy_lo);
      vy_lo >>= 16;
      y += 2;
    }
    if (n & (1 * sizeof(int8_t))) {
      *y = (int8_t) vy_lo;
    }
  }
}

void xnn_qu8_vcvt_ukernel__ssse3_x32(
    size_t n,
    const uint8_t* x,
    uint8_t* y,
    const union xnn_qu8_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(n != 0);
  assert(n % sizeof(uint8_t) == 0);
  assert(x != NULL);
  assert(y != NULL);

  const __m128i vinput_zero_point = _mm_load_si128((const __m128i*) params->ssse3.input_zero_point);
  const __m128i vmultiplier = _mm_load_si128((const __m128i*) params->ssse3.multiplier);
  const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->ssse3.output_zero_point);
  const __m128i vzero = _mm_setzero_si128();
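  // Same rescale as the signed variant, except uint8 inputs are zero-extended
  // by unpacking against vzero and results are packed with unsigned saturation.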
  for (; n >= 32 * sizeof(uint8_t); n -= 32 * sizeof(uint8_t)) {
    const __m128i vx0 = _mm_loadu_si128((const __m128i*) x);
    const __m128i vx1 = _mm_loadu_si128((const __m128i*) (x + 16));
    x += 32;

    __m128i vacc0 = _mm_unpacklo_epi8(vx0, vzero);
    __m128i vacc1 = _mm_unpackhi_epi8(vx0, vzero);
    __m128i vacc2 = _mm_unpacklo_epi8(vx1, vzero);
    __m128i vacc3 = _mm_unpackhi_epi8(vx1, vzero);

    vacc0 = _mm_sub_epi16(vinput_zero_point, vacc0);
    vacc1 = _mm_sub_epi16(vinput_zero_point, vacc1);
    vacc2 = _mm_sub_epi16(vinput_zero_point, vacc2);
    vacc3 = _mm_sub_epi16(vinput_zero_point, vacc3);

    vacc0 = _mm_slli_epi16(vacc0, 7);
    vacc1 = _mm_slli_epi16(vacc1, 7);
    vacc2 = _mm_slli_epi16(vacc2, 7);
    vacc3 = _mm_slli_epi16(vacc3, 7);

    vacc0 = _mm_mulhrs_epi16(vacc0, vmultiplier);
    vacc1 = _mm_mulhrs_epi16(vacc1, vmultiplier);
    vacc2 = _mm_mulhrs_epi16(vacc2, vmultiplier);
    vacc3 = _mm_mulhrs_epi16(vacc3, vmultiplier);

    vacc0 = _mm_adds_epi16(vacc0, voutput_zero_point);
    vacc1 = _mm_adds_epi16(vacc1, voutput_zero_point);
    vacc2 = _mm_adds_epi16(vacc2, voutput_zero_point);
    vacc3 = _mm_adds_epi16(vacc3, voutput_zero_point);

    const __m128i vy0 = _mm_packus_epi16(vacc0, vacc1);
    const __m128i vy1 = _mm_packus_epi16(vacc2, vacc3);

    _mm_storeu_si128((__m128i*) y, vy0);
    _mm_storeu_si128((__m128i*) (y + 16), vy1);
    y += 32;
  }
  for (; n >= 16 * sizeof(uint8_t); n -= 16 * sizeof(uint8_t)) {
    const __m128i vx = _mm_loadu_si128((const __m128i*) x);
    x += 16;

    __m128i vacc_lo = _mm_unpacklo_epi8(vx, vzero);
    __m128i vacc_hi = _mm_unpackhi_epi8(vx, vzero);
    vacc_lo = _mm_sub_epi16(vinput_zero_point, vacc_lo);
    vacc_hi = _mm_sub_epi16(vinput_zero_point, vacc_hi);
    vacc_lo = _mm_slli_epi16(vacc_lo, 7);
    vacc_hi = _mm_slli_epi16(vacc_hi, 7);
    vacc_lo = _mm_mulhrs_epi16(vacc_lo, vmultiplier);
    vacc_hi = _mm_mulhrs_epi16(vacc_hi, vmultiplier);
    vacc_lo = _mm_adds_epi16(vacc_lo, voutput_zero_point);
    vacc_hi = _mm_adds_epi16(vacc_hi, voutput_zero_point);

    const __m128i vy = _mm_packus_epi16(vacc_lo, vacc_hi);
    _mm_storeu_si128((__m128i*) y, vy);
    y += 16;
  }
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(uint8_t));
    assert(n <= 15 * sizeof(uint8_t));

    const __m128i vx = _mm_loadu_si128((const __m128i*) x);

    __m128i vacc_lo = _mm_unpacklo_epi8(vx, vzero);
    __m128i vacc_hi = _mm_unpackhi_epi8(vx, vzero);
    vacc_lo = _mm_sub_epi16(vinput_zero_point, vacc_lo);
    vacc_hi = _mm_sub_epi16(vinput_zero_point, vacc_hi);
    vacc_lo = _mm_slli_epi16(vacc_lo, 7);
    vacc_hi = _mm_slli_epi16(vacc_hi, 7);
    vacc_lo = _mm_mulhrs_epi16(vacc_lo, vmultiplier);
    vacc_hi = _mm_mulhrs_epi16(vacc_hi, vmultiplier);
    vacc_lo = _mm_adds_epi16(vacc_lo, voutput_zero_point);
    vacc_hi = _mm_adds_epi16(vacc_hi, voutput_zero_point);

    __m128i vy = _mm_packus_epi16(vacc_lo, vacc_hi);
    if (n & (8 * sizeof(uint8_t))) {
      _mm_storel_epi64((__m128i*) y, vy);
      vy = _mm_unpackhi_epi64(vy, vy);
      y += 8;
    }
    if (n & (4 * sizeof(uint8_t))) {
      unaligned_store_u32(y, (uint32_t) _mm_cvtsi128_si32(vy));
      vy = _mm_srli_epi64(vy, 32);
      y += 4;
    }
    uint32_t vy_lo = (uint32_t) _mm_cvtsi128_si32(vy);
    if (n & (2 * sizeof(uint8_t))) {
      unaligned_store_u16(y, (uint16_t) vy_lo);
      vy_lo >>= 16;
      y += 2;
    }
    if (n & (1 * sizeof(uint8_t))) {
      *y = (uint8_t) vy_lo;
    }
  }
}

void xnn_qu8_vlrelu_ukernel__ssse3_x32(
    size_t n,
    const uint8_t* x,
    uint8_t* y,
    const union xnn_qu8_lrelu_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(n != 0);
  assert(n % sizeof(uint8_t) == 0);
  assert(x != NULL);
  assert(y != NULL);

  const __m128i vinput_zero_point = _mm_load_si128((const __m128i*) params->sse2.input_zero_point);
  const __m128i vmultiplier_diff = _mm_load_si128((const __m128i*) params->sse2.multiplier_diff);
  const __m128i vmultiplier_base = _mm_load_si128((const __m128i*) params->sse2.multiplier_base);
  const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->sse2.output_zero_point);
  const __m128i vzero = _mm_setzero_si128();
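  // Unsigned LeakyReLU: the same AND/XOR multiplier selection as the QS8
  // kernel, with zero-extension on input and an unsigned saturating pack.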
  for (; n >= 32 * sizeof(uint8_t); n -= 32 * sizeof(uint8_t)) {
    const __m128i vx0 = _mm_loadu_si128((const __m128i*) x);
    const __m128i vx1 = _mm_loadu_si128((const __m128i*) (x + 16));
    x += 32;

    __m128i vacc0 = _mm_unpacklo_epi8(vx0, vzero);
    __m128i vacc1 = _mm_unpackhi_epi8(vx0, vzero);
    __m128i vacc2 = _mm_unpacklo_epi8(vx1, vzero);
    __m128i vacc3 = _mm_unpackhi_epi8(vx1, vzero);

    __m128i vmultiplier0 = _mm_cmpgt_epi16(vacc0, vinput_zero_point);
    vacc0 = _mm_sub_epi16(vinput_zero_point, vacc0);
    __m128i vmultiplier1 = _mm_cmpgt_epi16(vacc1, vinput_zero_point);
    vacc1 = _mm_sub_epi16(vinput_zero_point, vacc1);
    __m128i vmultiplier2 = _mm_cmpgt_epi16(vacc2, vinput_zero_point);
    vacc2 = _mm_sub_epi16(vinput_zero_point, vacc2);
    __m128i vmultiplier3 = _mm_cmpgt_epi16(vacc3, vinput_zero_point);
    vacc3 = _mm_sub_epi16(vinput_zero_point, vacc3);

    vmultiplier0 = _mm_and_si128(vmultiplier0, vmultiplier_diff);
    vacc0 = _mm_slli_epi16(vacc0, 7);
    vmultiplier0 = _mm_xor_si128(vmultiplier0, vmultiplier_base);
    vmultiplier1 = _mm_and_si128(vmultiplier1, vmultiplier_diff);
    vacc1 = _mm_slli_epi16(vacc1, 7);
    vmultiplier1 = _mm_xor_si128(vmultiplier1, vmultiplier_base);
    vmultiplier2 = _mm_and_si128(vmultiplier2, vmultiplier_diff);
    vacc2 = _mm_slli_epi16(vacc2, 7);
    vmultiplier2 = _mm_xor_si128(vmultiplier2, vmultiplier_base);
    vmultiplier3 = _mm_and_si128(vmultiplier3, vmultiplier_diff);
    vacc3 = _mm_slli_epi16(vacc3, 7);
    vmultiplier3 = _mm_xor_si128(vmultiplier3, vmultiplier_base);

    vacc0 = _mm_mulhrs_epi16(vacc0, vmultiplier0);
    vacc1 = _mm_mulhrs_epi16(vacc1, vmultiplier1);
    vacc2 = _mm_mulhrs_epi16(vacc2, vmultiplier2);
    vacc3 = _mm_mulhrs_epi16(vacc3, vmultiplier3);

    vacc0 = _mm_adds_epi16(vacc0, voutput_zero_point);
    vacc1 = _mm_adds_epi16(vacc1, voutput_zero_point);
    vacc2 = _mm_adds_epi16(vacc2, voutput_zero_point);
    vacc3 = _mm_adds_epi16(vacc3, voutput_zero_point);

    const __m128i vy0 = _mm_packus_epi16(vacc0, vacc1);
    const __m128i vy1 = _mm_packus_epi16(vacc2, vacc3);

    _mm_storeu_si128((__m128i*) y, vy0);
    _mm_storeu_si128((__m128i*) (y + 16), vy1);
    y += 32;
  }
  for (; n >= 16 * sizeof(uint8_t); n -= 16 * sizeof(uint8_t)) {
    const __m128i vx = _mm_loadu_si128((const __m128i*) x);
    x += 16;

    __m128i vacc_lo = _mm_unpacklo_epi8(vx, vzero);
    __m128i vacc_hi = _mm_unpackhi_epi8(vx, vzero);
    __m128i vmultiplier_lo = _mm_cmpgt_epi16(vacc_lo, vinput_zero_point);
    __m128i vmultiplier_hi = _mm_cmpgt_epi16(vacc_hi, vinput_zero_point);
    vacc_lo = _mm_sub_epi16(vinput_zero_point, vacc_lo);
    vacc_hi = _mm_sub_epi16(vinput_zero_point, vacc_hi);
    vmultiplier_lo = _mm_and_si128(vmultiplier_lo, vmultiplier_diff);
    vmultiplier_hi = _mm_and_si128(vmultiplier_hi, vmultiplier_diff);
    vacc_lo = _mm_slli_epi16(vacc_lo, 7);
    vacc_hi = _mm_slli_epi16(vacc_hi, 7);
    vmultiplier_lo = _mm_xor_si128(vmultiplier_lo, vmultiplier_base);
    vmultiplier_hi = _mm_xor_si128(vmultiplier_hi, vmultiplier_base);
    vacc_lo = _mm_mulhrs_epi16(vacc_lo, vmultiplier_lo);
    vacc_hi = _mm_mulhrs_epi16(vacc_hi, vmultiplier_hi);
    vacc_lo = _mm_adds_epi16(vacc_lo, voutput_zero_point);
    vacc_hi = _mm_adds_epi16(vacc_hi, voutput_zero_point);

    const __m128i vy = _mm_packus_epi16(vacc_lo, vacc_hi);
    _mm_storeu_si128((__m128i*) y, vy);
    y += 16;
  }
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(uint8_t));
    assert(n <= 15 * sizeof(uint8_t));

    const __m128i vx = _mm_loadu_si128((const __m128i*) x);

    __m128i vacc_lo = _mm_unpacklo_epi8(vx, vzero);
    __m128i vacc_hi = _mm_unpackhi_epi8(vx, vzero);
    __m128i vmultiplier_lo = _mm_cmpgt_epi16(vacc_lo, vinput_zero_point);
    __m128i vmultiplier_hi = _mm_cmpgt_epi16(vacc_hi, vinput_zero_point);
    vacc_lo = _mm_sub_epi16(vinput_zero_point, vacc_lo);
    vacc_hi = _mm_sub_epi16(vinput_zero_point, vacc_hi);
    vmultiplier_lo = _mm_and_si128(vmultiplier_lo, vmultiplier_diff);
    vmultiplier_hi = _mm_and_si128(vmultiplier_hi, vmultiplier_diff);
    vacc_lo = _mm_slli_epi16(vacc_lo, 7);
    vacc_hi = _mm_slli_epi16(vacc_hi, 7);
    vmultiplier_lo = _mm_xor_si128(vmultiplier_lo, vmultiplier_base);
    vmultiplier_hi = _mm_xor_si128(vmultiplier_hi, vmultiplier_base);
    vacc_lo = _mm_mulhrs_epi16(vacc_lo, vmultiplier_lo);
    vacc_hi = _mm_mulhrs_epi16(vacc_hi, vmultiplier_hi);
    vacc_lo = _mm_adds_epi16(vacc_lo, voutput_zero_point);
    vacc_hi = _mm_adds_epi16(vacc_hi, voutput_zero_point);

    __m128i vy = _mm_packus_epi16(vacc_lo, vacc_hi);
    if (n & (8 * sizeof(uint8_t))) {
      _mm_storel_epi64((__m128i*) y, vy);
      vy = _mm_unpackhi_epi64(vy, vy);
      y += 8;
    }
    if (n & (4 * sizeof(uint8_t))) {
      unaligned_store_u32(y, (uint32_t) _mm_cvtsi128_si32(vy));
      vy = _mm_srli_epi64(vy, 32);
      y += 4;
    }
    uint32_t vy_lo = (uint32_t) _mm_cvtsi128_si32(vy);
    if (n & (2 * sizeof(uint8_t))) {
      unaligned_store_u16(y, (uint16_t) vy_lo);
      vy_lo >>= 16;
      y += 2;
    }
    if (n & (1 * sizeof(uint8_t))) {
      *y = (uint8_t) vy_lo;
    }
  }
}