1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <assert.h>
#include <string.h>

#include <emmintrin.h>

#include <qnnpack/q8avgpool.h>
14
pytorch_q8avgpool_ukernel_up8x9__sse2(size_t n,size_t ks,size_t kc,const uint8_t ** input,const uint8_t * zero,uint8_t * output,size_t input_increment,size_t output_increment,const union pytorch_qnnp_avgpool_quantization_params quantization_params[RESTRICT_STATIC1])15 void pytorch_q8avgpool_ukernel_up8x9__sse2(
16 size_t n,
17 size_t ks,
18 size_t kc,
19 const uint8_t** input,
20 const uint8_t* zero,
21 uint8_t* output,
22 size_t input_increment,
23 size_t output_increment,
24 const union pytorch_qnnp_avgpool_quantization_params
25 quantization_params[RESTRICT_STATIC 1]) {
26 assert(n != 0);
27 assert(ks <= 9);
28 assert(kc >= 8);
29
30 const __m128i vbias =
31 _mm_load_si128((const __m128i*)&quantization_params->sse2.bias);
32 const __m128i vzero = _mm_setzero_si128();
33 const __m128 vscale = _mm_loadu_ps(quantization_params->sse2.scale);
34
35 do {
36 const uint8_t* i0 = input[0];
37 const uint8_t* i1 = input[1];
38 const uint8_t* i2 = input[2];
39 const uint8_t* i3 = input[3];
40 const uint8_t* i4 = input[4];
41 const uint8_t* i5 = input[5];
42 const uint8_t* i6 = input[6];
43 const uint8_t* i7 = input[7];
44 const uint8_t* i8 = input[8];
45 input = (const uint8_t**)((uintptr_t)input + input_increment);
46 if (ks < 2) {
47 i1 = zero;
48 }
49 if (ks <= 2) {
50 i2 = zero;
51 }
52 if (ks < 4) {
53 i3 = zero;
54 }
55 if (ks <= 4) {
56 i4 = zero;
57 }
58 if (ks < 6) {
59 i5 = zero;
60 }
61 if (ks <= 6) {
62 i6 = zero;
63 }
64 if (ks < 8) {
65 i7 = zero;
66 }
67 if (ks <= 8) {
68 i8 = zero;
69 }
70
71 size_t k = kc;
72 while (k >= 8) {
73 const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0);
74 i0 += 8;
75 const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1);
76 i1 += 8;
77 const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2);
78 i2 += 8;
79 const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3);
80 i3 += 8;
81 const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4);
82 i4 += 8;
83 const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5);
84 i5 += 8;
85 const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6);
86 i6 += 8;
87 const __m128i vi7 = _mm_loadl_epi64((const __m128i*)i7);
88 i7 += 8;
89 const __m128i vi8 = _mm_loadl_epi64((const __m128i*)i8);
90 i8 += 8;
91
92 const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
93 const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
94 const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
95 const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
96 const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
97 const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
98 const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
99 const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero);
100 const __m128i vxi8 = _mm_unpacklo_epi8(vi8, vzero);
101
102 const __m128i vsum018 = _mm_add_epi16(_mm_add_epi16(vxi0, vxi1), vxi8);
103 const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3);
104 const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5);
105 const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7);
106
107 const __m128i vsum2345 = _mm_add_epi16(vsum23, vsum45);
108 const __m128i vsum01678 = _mm_add_epi16(vsum018, vsum67);
109 const __m128i vsum = _mm_add_epi16(vsum2345, vsum01678);
110
111 const __m128i vacc_lo =
112 _mm_add_epi32(vbias, _mm_unpacklo_epi16(vsum, vzero));
113 const __m128i vacc_hi =
114 _mm_add_epi32(vbias, _mm_unpackhi_epi16(vsum, vzero));
115
116 const __m128 vacc_lo_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_lo), vscale);
117 const __m128 vacc_hi_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_hi), vscale);
118
119 const __m128i vscaled_lo = _mm_cvtps_epi32(vacc_lo_f);
120 const __m128i vscaled_hi = _mm_cvtps_epi32(vacc_hi_f);
121
122 __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
123 vout = _mm_adds_epi16(
124 vout,
125 _mm_load_si128(
126 (const __m128i*)&quantization_params->sse2.output_zero_point));
127 vout = _mm_packus_epi16(vout, vout);
128 vout = _mm_min_epu8(
129 vout,
130 _mm_load_si128(
131 (const __m128i*)&quantization_params->sse2.output_max));
132 vout = _mm_max_epu8(
133 vout,
134 _mm_load_si128(
135 (const __m128i*)&quantization_params->sse2.output_min));
136
137 _mm_storel_epi64((__m128i*)output, vout);
138 output += 8;
139
140 k -= 8;
141 }
142 if (k != 0) {
143 const size_t address_decrement = 8 - k;
144 i0 = (const uint8_t*)((uintptr_t)i0 - address_decrement);
145 i1 = (const uint8_t*)((uintptr_t)i1 - address_decrement);
146 i2 = (const uint8_t*)((uintptr_t)i2 - address_decrement);
147 i3 = (const uint8_t*)((uintptr_t)i3 - address_decrement);
148 i4 = (const uint8_t*)((uintptr_t)i4 - address_decrement);
149 i5 = (const uint8_t*)((uintptr_t)i5 - address_decrement);
150 i6 = (const uint8_t*)((uintptr_t)i6 - address_decrement);
151 i7 = (const uint8_t*)((uintptr_t)i7 - address_decrement);
152 i8 = (const uint8_t*)((uintptr_t)i8 - address_decrement);
153 const __m128i vshift = _mm_cvtsi32_si128(8 * address_decrement);
154
155 const __m128i vi0 =
156 _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i0), vshift);
157 const __m128i vi1 =
158 _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i1), vshift);
159 const __m128i vi2 =
160 _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i2), vshift);
161 const __m128i vi3 =
162 _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i3), vshift);
163 const __m128i vi4 =
164 _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i4), vshift);
165 const __m128i vi5 =
166 _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i5), vshift);
167 const __m128i vi6 =
168 _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i6), vshift);
169 const __m128i vi7 =
170 _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i7), vshift);
171 const __m128i vi8 =
172 _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i8), vshift);
173
174 const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
175 const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
176 const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
177 const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
178 const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
179 const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
180 const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
181 const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero);
182 const __m128i vxi8 = _mm_unpacklo_epi8(vi8, vzero);
183
184 const __m128i vsum018 = _mm_add_epi16(_mm_add_epi16(vxi0, vxi1), vxi8);
185 const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3);
186 const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5);
187 const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7);
188
189 const __m128i vsum2345 = _mm_add_epi16(vsum23, vsum45);
190 const __m128i vsum01678 = _mm_add_epi16(vsum018, vsum67);
191 const __m128i vsum = _mm_add_epi16(vsum2345, vsum01678);
192
193 const __m128i vacc_lo =
194 _mm_add_epi32(vbias, _mm_unpacklo_epi16(vsum, vzero));
195 const __m128i vacc_hi =
196 _mm_add_epi32(vbias, _mm_unpackhi_epi16(vsum, vzero));
197
198 const __m128 vacc_lo_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_lo), vscale);
199 const __m128 vacc_hi_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_hi), vscale);
200
201 const __m128i vscaled_lo = _mm_cvtps_epi32(vacc_lo_f);
202 const __m128i vscaled_hi = _mm_cvtps_epi32(vacc_hi_f);
203
204 __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
205 vout = _mm_adds_epi16(
206 vout,
207 _mm_load_si128(
208 (const __m128i*)&quantization_params->sse2.output_zero_point));
209 vout = _mm_packus_epi16(vout, vout);
210 vout = _mm_min_epu8(
211 vout,
212 _mm_load_si128(
213 (const __m128i*)&quantization_params->sse2.output_max));
214 vout = _mm_max_epu8(
215 vout,
216 _mm_load_si128(
217 (const __m128i*)&quantization_params->sse2.output_min));
218
219 if (k & 4) {
220 *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout);
221 output += 4;
222 vout = _mm_srli_epi64(vout, 32);
223 }
224 if (k & 2) {
225 *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0);
226 output += 2;
227 vout = _mm_srli_epi32(vout, 16);
228 }
229 if (k & 1) {
230 *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout);
231 output += 1;
232 }
233 }
234 output = (uint8_t*)((uintptr_t)output + output_increment);
235 } while (--n != 0);
236 }
237