// Auto-generated file. Do not edit!
//   Template: src/qs8-dwconv/unipass-wasmsimd-mul16.c.in
//   Generator: tools/xngen
//
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <wasm_simd128.h>

#include <xnnpack/dwconv.h>


void xnn_qu8_dwconv_minmax_fp32_ukernel_up8x9__wasmsimd_mul16(
    size_t channels,
    size_t output_width,
    const uint8_t** input,
    const void* weights,
    uint8_t* output,
    size_t input_stride,
    size_t output_increment,
    size_t input_offset,
    const uint8_t* zero,
    const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(channels != 0);
  assert(output_width != 0);

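  // The per-layer kernel zero point, widened from 16-bit to 32-bit lanes.
  // Kernel bytes are stored unsigned, so kernel_zero_point * sum(inputs) is
  // subtracted from the accumulators once all taps have been summed.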
  const v128_t vkernel_zero_point = wasm_u32x4_load16x4(params->fp32_wasmsimd.kernel_zero_point);
  do {
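    // Gather the 9 input row pointers for this output pixel; rows that point
    // at the shared zero (padding) buffer are not adjusted by input_offset.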
    const uint8_t* i0 = input[0];
    assert(i0 != NULL);
    if XNN_UNPREDICTABLE(i0 != zero) {
      i0 = (const uint8_t*) ((uintptr_t) i0 + input_offset);
    }
    const uint8_t* i1 = input[1];
    assert(i1 != NULL);
    if XNN_UNPREDICTABLE(i1 != zero) {
      i1 = (const uint8_t*) ((uintptr_t) i1 + input_offset);
    }
    const uint8_t* i2 = input[2];
    assert(i2 != NULL);
    if XNN_UNPREDICTABLE(i2 != zero) {
      i2 = (const uint8_t*) ((uintptr_t) i2 + input_offset);
    }
    const uint8_t* i3 = input[3];
    assert(i3 != NULL);
    if XNN_UNPREDICTABLE(i3 != zero) {
      i3 = (const uint8_t*) ((uintptr_t) i3 + input_offset);
    }
    const uint8_t* i4 = input[4];
    assert(i4 != NULL);
    if XNN_UNPREDICTABLE(i4 != zero) {
      i4 = (const uint8_t*) ((uintptr_t) i4 + input_offset);
    }
    const uint8_t* i5 = input[5];
    assert(i5 != NULL);
    if XNN_UNPREDICTABLE(i5 != zero) {
      i5 = (const uint8_t*) ((uintptr_t) i5 + input_offset);
    }
    const uint8_t* i6 = input[6];
    assert(i6 != NULL);
    if XNN_UNPREDICTABLE(i6 != zero) {
      i6 = (const uint8_t*) ((uintptr_t) i6 + input_offset);
    }
    const uint8_t* i7 = input[7];
    assert(i7 != NULL);
    if XNN_UNPREDICTABLE(i7 != zero) {
      i7 = (const uint8_t*) ((uintptr_t) i7 + input_offset);
    }
    const uint8_t* i8 = input[8];
    assert(i8 != NULL);
    if XNN_UNPREDICTABLE(i8 != zero) {
      i8 = (const uint8_t*) ((uintptr_t) i8 + input_offset);
    }
    input = (const uint8_t**) ((uintptr_t) input + input_stride);

    size_t c = channels;
    const void* w = weights;
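    // Main loop: 8 channels per iteration. The packed weights hold 8 int32
    // bias values followed by 9 groups of 8 uint8 kernel taps.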
    for (; c >= 8; c -= 8) {
      v128_t vacc0123 = wasm_v128_load(w);
      v128_t vacc4567 = wasm_v128_load((const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));

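      // For each of the 9 taps: zero-extend 8 input and 8 kernel bytes to
      // 16 bits, multiply, and add the widened low/high product halves into
      // the two 32-bit accumulators. vsumx01234567 collects the raw input
      // values for the kernel-zero-point correction applied after the taps.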
      const v128_t vi0x01234567 = wasm_u16x8_load8x8(i0);
      const v128_t vk0x01234567 = wasm_u16x8_load8x8((const void*) ((uintptr_t) w + 8 * sizeof(int32_t) + 0 * sizeof(uint8_t)));
      i0 += 8;

      v128_t vprod01234567 = wasm_i16x8_mul(vi0x01234567, vk0x01234567);

      vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vprod01234567));
      vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vprod01234567));

      const v128_t vi1x01234567 = wasm_u16x8_load8x8(i1);
      const v128_t vk1x01234567 = wasm_u16x8_load8x8((const void*) ((uintptr_t) w + 8 * sizeof(int32_t) + 8 * sizeof(uint8_t)));
      v128_t vsumx01234567 = wasm_i16x8_add(vi0x01234567, vi1x01234567);
      i1 += 8;

      vprod01234567 = wasm_i16x8_mul(vi1x01234567, vk1x01234567);

      vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vprod01234567));
      vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vprod01234567));

      const v128_t vi2x01234567 = wasm_u16x8_load8x8(i2);
      const v128_t vk2x01234567 = wasm_u16x8_load8x8((const void*) ((uintptr_t) w + 8 * sizeof(int32_t) + 16 * sizeof(uint8_t)));
      vsumx01234567 = wasm_i16x8_add(vsumx01234567, vi2x01234567);
      i2 += 8;

      vprod01234567 = wasm_i16x8_mul(vi2x01234567, vk2x01234567);

      vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vprod01234567));
      vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vprod01234567));

      const v128_t vi3x01234567 = wasm_u16x8_load8x8(i3);
      const v128_t vk3x01234567 = wasm_u16x8_load8x8((const void*) ((uintptr_t) w + 8 * sizeof(int32_t) + 24 * sizeof(uint8_t)));
      vsumx01234567 = wasm_i16x8_add(vsumx01234567, vi3x01234567);
      i3 += 8;

      vprod01234567 = wasm_i16x8_mul(vi3x01234567, vk3x01234567);

      vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vprod01234567));
      vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vprod01234567));

      const v128_t vi4x01234567 = wasm_u16x8_load8x8(i4);
      const v128_t vk4x01234567 = wasm_u16x8_load8x8((const void*) ((uintptr_t) w + 8 * sizeof(int32_t) + 32 * sizeof(uint8_t)));
      vsumx01234567 = wasm_i16x8_add(vsumx01234567, vi4x01234567);
      i4 += 8;

      vprod01234567 = wasm_i16x8_mul(vi4x01234567, vk4x01234567);

      vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vprod01234567));
      vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vprod01234567));

      const v128_t vi5x01234567 = wasm_u16x8_load8x8(i5);
      const v128_t vk5x01234567 = wasm_u16x8_load8x8((const void*) ((uintptr_t) w + 8 * sizeof(int32_t) + 40 * sizeof(uint8_t)));
      vsumx01234567 = wasm_i16x8_add(vsumx01234567, vi5x01234567);
      i5 += 8;

      vprod01234567 = wasm_i16x8_mul(vi5x01234567, vk5x01234567);

      vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vprod01234567));
      vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vprod01234567));

      const v128_t vi6x01234567 = wasm_u16x8_load8x8(i6);
      const v128_t vk6x01234567 = wasm_u16x8_load8x8((const void*) ((uintptr_t) w + 8 * sizeof(int32_t) + 48 * sizeof(uint8_t)));
      vsumx01234567 = wasm_i16x8_add(vsumx01234567, vi6x01234567);
      i6 += 8;

      vprod01234567 = wasm_i16x8_mul(vi6x01234567, vk6x01234567);

      vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vprod01234567));
      vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vprod01234567));

      const v128_t vi7x01234567 = wasm_u16x8_load8x8(i7);
      const v128_t vk7x01234567 = wasm_u16x8_load8x8((const void*) ((uintptr_t) w + 8 * sizeof(int32_t) + 56 * sizeof(uint8_t)));
      vsumx01234567 = wasm_i16x8_add(vsumx01234567, vi7x01234567);
      i7 += 8;

      vprod01234567 = wasm_i16x8_mul(vi7x01234567, vk7x01234567);

      vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vprod01234567));
      vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vprod01234567));

      const v128_t vi8x01234567 = wasm_u16x8_load8x8(i8);
      const v128_t vk8x01234567 = wasm_u16x8_load8x8((const void*) ((uintptr_t) w + 8 * sizeof(int32_t) + 64 * sizeof(uint8_t)));
      vsumx01234567 = wasm_i16x8_add(vsumx01234567, vi8x01234567);
      i8 += 8;

      vprod01234567 = wasm_i16x8_mul(vi8x01234567, vk8x01234567);

      vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vprod01234567));
      vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vprod01234567));

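      // acc -= kernel_zero_point * sum(inputs): compensates for multiplying
      // by the raw unsigned kernel bytes instead of (kernel - zero point).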
      vacc0123 = wasm_i32x4_sub(vacc0123, wasm_i32x4_mul(wasm_u32x4_extend_low_u16x8(vsumx01234567), vkernel_zero_point));
      vacc4567 = wasm_i32x4_sub(vacc4567, wasm_i32x4_mul(wasm_u32x4_extend_high_u16x8(vsumx01234567), vkernel_zero_point));

      w = (const void*) ((uintptr_t) w + 8 * sizeof(int32_t) + 72 * sizeof(uint8_t));

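      // fp32 requantization: convert to float and scale, then add a magic
      // bias so the rounded integer lands in the low mantissa bits; the max
      // against magic_min and the subtraction of
      // magic_bias_less_output_zero_point apply the lower clamp and the
      // output zero point while treating the float bit patterns as integers.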
      vacc0123 = wasm_f32x4_convert_i32x4(vacc0123);
      vacc4567 = wasm_f32x4_convert_i32x4(vacc4567);

      const v128_t vscale = wasm_v128_load64_splat(params->fp32_wasmsimd.scale);
      vacc0123 = wasm_f32x4_mul(vacc0123, vscale);
      vacc4567 = wasm_f32x4_mul(vacc4567, vscale);

      const v128_t vmagic_bias = wasm_v128_load64_splat(params->fp32_wasmsimd.magic_bias);
      vacc0123 = wasm_f32x4_add(vacc0123, vmagic_bias);
      vacc4567 = wasm_f32x4_add(vacc4567, vmagic_bias);

      const v128_t vmagic_min = wasm_v128_load64_splat(params->fp32_wasmsimd.magic_min);
      vacc0123 = wasm_i32x4_max(vacc0123, vmagic_min);
      vacc4567 = wasm_i32x4_max(vacc4567, vmagic_min);

      const v128_t vmagic_bias_less_output_zero_point = wasm_v128_load64_splat(params->fp32_wasmsimd.magic_bias_less_output_zero_point);
      vacc0123 = wasm_i32x4_sub(vacc0123, vmagic_bias_less_output_zero_point);
      vacc4567 = wasm_i32x4_sub(vacc4567, vmagic_bias_less_output_zero_point);

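      // Pack the eight 32-bit results to 16 bits and then to unsigned 8 bits
      // (both narrowing steps saturate), and apply the upper clamp.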
      v128_t vout01234567 = wasm_i16x8_narrow_i32x4(vacc0123, vacc4567);

      v128_t vout0123456701234567 = wasm_u8x16_narrow_i16x8(vout01234567, vout01234567);

      const v128_t voutput_max = wasm_v128_load64_splat(params->fp32_wasmsimd.output_max);
      vout0123456701234567 = wasm_u8x16_min(vout0123456701234567, voutput_max);

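      // Store all 8 output bytes at once via a 64-bit lane extract.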
      *((double*) output) = wasm_f64x2_extract_lane(vout0123456701234567, 0);
      output += 8;
    }
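    // Remainder: compute one more full block of 8 channels (vector loads may
    // read past the end of the buffers, hence XNN_OOB_READS) but store only
    // the final 1-7 bytes.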
    if XNN_UNLIKELY(c != 0) {
      {
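        // Same 9-tap accumulation as the main loop, without advancing the
        // input or weight pointers.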
        v128_t vacc0123 = wasm_v128_load(w);
        v128_t vacc4567 = wasm_v128_load((const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));

        const v128_t vi0x01234567 = wasm_u16x8_load8x8(i0);
        const v128_t vk0x01234567 = wasm_u16x8_load8x8((const void*) ((uintptr_t) w + 8 * sizeof(int32_t) + 0 * sizeof(uint8_t)));

        v128_t vprod01234567 = wasm_i16x8_mul(vi0x01234567, vk0x01234567);

        vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vprod01234567));
        vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vprod01234567));

        const v128_t vi1x01234567 = wasm_u16x8_load8x8(i1);
        const v128_t vk1x01234567 = wasm_u16x8_load8x8((const void*) ((uintptr_t) w + 8 * sizeof(int32_t) + 8 * sizeof(uint8_t)));
        v128_t vsumx01234567 = wasm_i16x8_add(vi0x01234567, vi1x01234567);

        vprod01234567 = wasm_i16x8_mul(vi1x01234567, vk1x01234567);

        vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vprod01234567));
        vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vprod01234567));

        const v128_t vi2x01234567 = wasm_u16x8_load8x8(i2);
        const v128_t vk2x01234567 = wasm_u16x8_load8x8((const void*) ((uintptr_t) w + 8 * sizeof(int32_t) + 16 * sizeof(uint8_t)));
        vsumx01234567 = wasm_i16x8_add(vsumx01234567, vi2x01234567);

        vprod01234567 = wasm_i16x8_mul(vi2x01234567, vk2x01234567);

        vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vprod01234567));
        vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vprod01234567));

        const v128_t vi3x01234567 = wasm_u16x8_load8x8(i3);
        const v128_t vk3x01234567 = wasm_u16x8_load8x8((const void*) ((uintptr_t) w + 8 * sizeof(int32_t) + 24 * sizeof(uint8_t)));
        vsumx01234567 = wasm_i16x8_add(vsumx01234567, vi3x01234567);

        vprod01234567 = wasm_i16x8_mul(vi3x01234567, vk3x01234567);

        vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vprod01234567));
        vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vprod01234567));

        const v128_t vi4x01234567 = wasm_u16x8_load8x8(i4);
        const v128_t vk4x01234567 = wasm_u16x8_load8x8((const void*) ((uintptr_t) w + 8 * sizeof(int32_t) + 32 * sizeof(uint8_t)));
        vsumx01234567 = wasm_i16x8_add(vsumx01234567, vi4x01234567);

        vprod01234567 = wasm_i16x8_mul(vi4x01234567, vk4x01234567);

        vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vprod01234567));
        vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vprod01234567));

        const v128_t vi5x01234567 = wasm_u16x8_load8x8(i5);
        const v128_t vk5x01234567 = wasm_u16x8_load8x8((const void*) ((uintptr_t) w + 8 * sizeof(int32_t) + 40 * sizeof(uint8_t)));
        vsumx01234567 = wasm_i16x8_add(vsumx01234567, vi5x01234567);

        vprod01234567 = wasm_i16x8_mul(vi5x01234567, vk5x01234567);

        vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vprod01234567));
        vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vprod01234567));

        const v128_t vi6x01234567 = wasm_u16x8_load8x8(i6);
        const v128_t vk6x01234567 = wasm_u16x8_load8x8((const void*) ((uintptr_t) w + 8 * sizeof(int32_t) + 48 * sizeof(uint8_t)));
        vsumx01234567 = wasm_i16x8_add(vsumx01234567, vi6x01234567);

        vprod01234567 = wasm_i16x8_mul(vi6x01234567, vk6x01234567);

        vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vprod01234567));
        vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vprod01234567));

        const v128_t vi7x01234567 = wasm_u16x8_load8x8(i7);
        const v128_t vk7x01234567 = wasm_u16x8_load8x8((const void*) ((uintptr_t) w + 8 * sizeof(int32_t) + 56 * sizeof(uint8_t)));
        vsumx01234567 = wasm_i16x8_add(vsumx01234567, vi7x01234567);

        vprod01234567 = wasm_i16x8_mul(vi7x01234567, vk7x01234567);

        vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vprod01234567));
        vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vprod01234567));

        const v128_t vi8x01234567 = wasm_u16x8_load8x8(i8);
        const v128_t vk8x01234567 = wasm_u16x8_load8x8((const void*) ((uintptr_t) w + 8 * sizeof(int32_t) + 64 * sizeof(uint8_t)));
        vsumx01234567 = wasm_i16x8_add(vsumx01234567, vi8x01234567);

        vprod01234567 = wasm_i16x8_mul(vi8x01234567, vk8x01234567);

        vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vprod01234567));
        vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vprod01234567));

        vacc0123 = wasm_i32x4_sub(vacc0123, wasm_i32x4_mul(wasm_u32x4_extend_low_u16x8(vsumx01234567), vkernel_zero_point));
        vacc4567 = wasm_i32x4_sub(vacc4567, wasm_i32x4_mul(wasm_u32x4_extend_high_u16x8(vsumx01234567), vkernel_zero_point));

        vacc0123 = wasm_f32x4_convert_i32x4(vacc0123);
        vacc4567 = wasm_f32x4_convert_i32x4(vacc4567);

        const v128_t vscale = wasm_v128_load64_splat(params->fp32_wasmsimd.scale);
        vacc0123 = wasm_f32x4_mul(vacc0123, vscale);
        vacc4567 = wasm_f32x4_mul(vacc4567, vscale);

        const v128_t vmagic_bias = wasm_v128_load64_splat(params->fp32_wasmsimd.magic_bias);
        vacc0123 = wasm_f32x4_add(vacc0123, vmagic_bias);
        vacc4567 = wasm_f32x4_add(vacc4567, vmagic_bias);

        const v128_t vmagic_min = wasm_v128_load64_splat(params->fp32_wasmsimd.magic_min);
        vacc0123 = wasm_i32x4_max(vacc0123, vmagic_min);
        vacc4567 = wasm_i32x4_max(vacc4567, vmagic_min);

        const v128_t vmagic_bias_less_output_zero_point = wasm_v128_load64_splat(params->fp32_wasmsimd.magic_bias_less_output_zero_point);
        vacc0123 = wasm_i32x4_sub(vacc0123, vmagic_bias_less_output_zero_point);
        vacc4567 = wasm_i32x4_sub(vacc4567, vmagic_bias_less_output_zero_point);

        v128_t vout01234567 = wasm_i16x8_narrow_i32x4(vacc0123, vacc4567);
        v128_t vout0123456701234567 = wasm_u8x16_narrow_i16x8(vout01234567, vout01234567);

        const v128_t voutput_max = wasm_v128_load64_splat(params->fp32_wasmsimd.output_max);
        vout0123456701234567 = wasm_u8x16_min(vout0123456701234567, voutput_max);

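        // Write the remaining channels in 4-, 2-, and 1-byte pieces, shifting
        // consumed lanes out of the vector as we go.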
        if (c & 4) {
          *((float*) output) = wasm_f32x4_extract_lane(vout0123456701234567, 0);
          vout0123456701234567 = wasm_u64x2_shr(vout0123456701234567, 32);
          output += 4;
        }
        uint32_t vout0123 = wasm_i32x4_extract_lane(vout0123456701234567, 0);
        if (c & 2) {
          *((uint16_t*) output) = (uint16_t) vout0123;
          vout0123 >>= 16;
          output += 2;
        }
        if (c & 1) {
          *output = (uint8_t) vout0123;
          output += 1;
        }
      }
    }

    output = (uint8_t*) ((uintptr_t) output + output_increment);
  } while (--output_width != 0);
}