/* xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/up8x9-sse2.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45) */
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <assert.h>

#include <emmintrin.h>

#include <qnnpack/q8avgpool.h>

pytorch_q8avgpool_ukernel_up8x9__sse2(size_t n,size_t ks,size_t kc,const uint8_t ** input,const uint8_t * zero,uint8_t * output,size_t input_increment,size_t output_increment,const union pytorch_qnnp_avgpool_quantization_params quantization_params[RESTRICT_STATIC1])15 void pytorch_q8avgpool_ukernel_up8x9__sse2(
16     size_t n,
17     size_t ks,
18     size_t kc,
19     const uint8_t** input,
20     const uint8_t* zero,
21     uint8_t* output,
22     size_t input_increment,
23     size_t output_increment,
24     const union pytorch_qnnp_avgpool_quantization_params
25         quantization_params[RESTRICT_STATIC 1]) {
26   assert(n != 0);
27   assert(ks <= 9);
28   assert(kc >= 8);
29 
30   const __m128i vbias =
31       _mm_load_si128((const __m128i*)&quantization_params->sse2.bias);
32   const __m128i vzero = _mm_setzero_si128();
33   const __m128 vscale = _mm_loadu_ps(quantization_params->sse2.scale);
34 
35   do {
36     const uint8_t* i0 = input[0];
37     const uint8_t* i1 = input[1];
38     const uint8_t* i2 = input[2];
39     const uint8_t* i3 = input[3];
40     const uint8_t* i4 = input[4];
41     const uint8_t* i5 = input[5];
42     const uint8_t* i6 = input[6];
43     const uint8_t* i7 = input[7];
44     const uint8_t* i8 = input[8];
45     input = (const uint8_t**)((uintptr_t)input + input_increment);
46     if (ks < 2) {
47       i1 = zero;
48     }
49     if (ks <= 2) {
50       i2 = zero;
51     }
52     if (ks < 4) {
53       i3 = zero;
54     }
55     if (ks <= 4) {
56       i4 = zero;
57     }
58     if (ks < 6) {
59       i5 = zero;
60     }
61     if (ks <= 6) {
62       i6 = zero;
63     }
64     if (ks < 8) {
65       i7 = zero;
66     }
67     if (ks <= 8) {
68       i8 = zero;
69     }
70 
71     size_t k = kc;
72     while (k >= 8) {
73       const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0);
74       i0 += 8;
75       const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1);
76       i1 += 8;
77       const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2);
78       i2 += 8;
79       const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3);
80       i3 += 8;
81       const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4);
82       i4 += 8;
83       const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5);
84       i5 += 8;
85       const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6);
86       i6 += 8;
87       const __m128i vi7 = _mm_loadl_epi64((const __m128i*)i7);
88       i7 += 8;
89       const __m128i vi8 = _mm_loadl_epi64((const __m128i*)i8);
90       i8 += 8;
91 
92       const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
93       const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
94       const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
95       const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
96       const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
97       const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
98       const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
99       const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero);
100       const __m128i vxi8 = _mm_unpacklo_epi8(vi8, vzero);
101 
102       const __m128i vsum018 = _mm_add_epi16(_mm_add_epi16(vxi0, vxi1), vxi8);
103       const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3);
104       const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5);
105       const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7);
106 
107       const __m128i vsum2345 = _mm_add_epi16(vsum23, vsum45);
108       const __m128i vsum01678 = _mm_add_epi16(vsum018, vsum67);
109       const __m128i vsum = _mm_add_epi16(vsum2345, vsum01678);
110 
111       const __m128i vacc_lo =
112           _mm_add_epi32(vbias, _mm_unpacklo_epi16(vsum, vzero));
113       const __m128i vacc_hi =
114           _mm_add_epi32(vbias, _mm_unpackhi_epi16(vsum, vzero));
115 
116       const __m128 vacc_lo_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_lo), vscale);
117       const __m128 vacc_hi_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_hi), vscale);
118 
119       const __m128i vscaled_lo = _mm_cvtps_epi32(vacc_lo_f);
120       const __m128i vscaled_hi = _mm_cvtps_epi32(vacc_hi_f);
121 
122       __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
123       vout = _mm_adds_epi16(
124           vout,
125           _mm_load_si128(
126               (const __m128i*)&quantization_params->sse2.output_zero_point));
127       vout = _mm_packus_epi16(vout, vout);
128       vout = _mm_min_epu8(
129           vout,
130           _mm_load_si128(
131               (const __m128i*)&quantization_params->sse2.output_max));
132       vout = _mm_max_epu8(
133           vout,
134           _mm_load_si128(
135               (const __m128i*)&quantization_params->sse2.output_min));
136 
137       _mm_storel_epi64((__m128i*)output, vout);
138       output += 8;
139 
140       k -= 8;
141     }
142     if (k != 0) {
143       const size_t address_decrement = 8 - k;
144       i0 = (const uint8_t*)((uintptr_t)i0 - address_decrement);
145       i1 = (const uint8_t*)((uintptr_t)i1 - address_decrement);
146       i2 = (const uint8_t*)((uintptr_t)i2 - address_decrement);
147       i3 = (const uint8_t*)((uintptr_t)i3 - address_decrement);
148       i4 = (const uint8_t*)((uintptr_t)i4 - address_decrement);
149       i5 = (const uint8_t*)((uintptr_t)i5 - address_decrement);
150       i6 = (const uint8_t*)((uintptr_t)i6 - address_decrement);
151       i7 = (const uint8_t*)((uintptr_t)i7 - address_decrement);
152       i8 = (const uint8_t*)((uintptr_t)i8 - address_decrement);
153       const __m128i vshift = _mm_cvtsi32_si128(8 * address_decrement);
154 
155       const __m128i vi0 =
156           _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i0), vshift);
157       const __m128i vi1 =
158           _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i1), vshift);
159       const __m128i vi2 =
160           _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i2), vshift);
161       const __m128i vi3 =
162           _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i3), vshift);
163       const __m128i vi4 =
164           _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i4), vshift);
165       const __m128i vi5 =
166           _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i5), vshift);
167       const __m128i vi6 =
168           _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i6), vshift);
169       const __m128i vi7 =
170           _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i7), vshift);
171       const __m128i vi8 =
172           _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i8), vshift);
173 
174       const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
175       const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
176       const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
177       const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
178       const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
179       const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
180       const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
181       const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero);
182       const __m128i vxi8 = _mm_unpacklo_epi8(vi8, vzero);
183 
184       const __m128i vsum018 = _mm_add_epi16(_mm_add_epi16(vxi0, vxi1), vxi8);
185       const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3);
186       const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5);
187       const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7);
188 
189       const __m128i vsum2345 = _mm_add_epi16(vsum23, vsum45);
190       const __m128i vsum01678 = _mm_add_epi16(vsum018, vsum67);
191       const __m128i vsum = _mm_add_epi16(vsum2345, vsum01678);
192 
193       const __m128i vacc_lo =
194           _mm_add_epi32(vbias, _mm_unpacklo_epi16(vsum, vzero));
195       const __m128i vacc_hi =
196           _mm_add_epi32(vbias, _mm_unpackhi_epi16(vsum, vzero));
197 
198       const __m128 vacc_lo_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_lo), vscale);
199       const __m128 vacc_hi_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_hi), vscale);
200 
201       const __m128i vscaled_lo = _mm_cvtps_epi32(vacc_lo_f);
202       const __m128i vscaled_hi = _mm_cvtps_epi32(vacc_hi_f);
203 
204       __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
205       vout = _mm_adds_epi16(
206           vout,
207           _mm_load_si128(
208               (const __m128i*)&quantization_params->sse2.output_zero_point));
209       vout = _mm_packus_epi16(vout, vout);
210       vout = _mm_min_epu8(
211           vout,
212           _mm_load_si128(
213               (const __m128i*)&quantization_params->sse2.output_max));
214       vout = _mm_max_epu8(
215           vout,
216           _mm_load_si128(
217               (const __m128i*)&quantization_params->sse2.output_min));
218 
219       if (k & 4) {
220         *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout);
221         output += 4;
222         vout = _mm_srli_epi64(vout, 32);
223       }
224       if (k & 2) {
225         *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0);
226         output += 2;
227         vout = _mm_srli_epi32(vout, 16);
228       }
229       if (k & 1) {
230         *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout);
231         output += 1;
232       }
233     }
234     output = (uint8_t*)((uintptr_t)output + output_increment);
235   } while (--n != 0);
236 }
237