1 // Copyright 2021 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 
8 #include <immintrin.h>
9 
10 #include <xnnpack/common.h>
11 #include <xnnpack/dwconv.h>
12 #include <xnnpack/gemm.h>
13 #include <xnnpack/igemm.h>
14 #include <xnnpack/intrinsics-polyfill.h>
15 #include <xnnpack/math.h>
16 #include <xnnpack/prelu.h>
17 #include <xnnpack/vbinary.h>
18 #include <xnnpack/vunary.h>
19 
20 
21 void xnn_f32_dwconv_minmax_ukernel_up16x25__avx512f(
22     size_t channels,
23     size_t output_width,
24     const float** input,
25     const float* weights,
26     float* output,
27     size_t input_stride,
28     size_t output_increment,
29     size_t input_offset,
30     const float* zero,
31     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
32 {
33   assert(channels != 0);
34   assert(output_width != 0);
35 
36   const __m512 vmax = _mm512_set1_ps(params->scalar.max);
37   const __m512 vmin = _mm512_set1_ps(params->scalar.min);
38   do {
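    // `input` is the indirection buffer entry for the current output pixel: one pointer
    // per kernel tap (25 here). Taps that fall into padding point at the shared `zero`
    // buffer and are used as-is; every other pointer is shifted by `input_offset` (a byte
    // offset into the input tensor) before its channel values are read.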
39     const float* i0 = input[0];
40     assert(i0 != NULL);
41     if XNN_UNPREDICTABLE(i0 != zero) {
42       i0 = (const float*) ((uintptr_t) i0 + input_offset);
43     }
44     const float* i1 = input[1];
45     assert(i1 != NULL);
46     if XNN_UNPREDICTABLE(i1 != zero) {
47       i1 = (const float*) ((uintptr_t) i1 + input_offset);
48     }
49     const float* i2 = input[2];
50     assert(i2 != NULL);
51     if XNN_UNPREDICTABLE(i2 != zero) {
52       i2 = (const float*) ((uintptr_t) i2 + input_offset);
53     }
54     const float* i3 = input[3];
55     assert(i3 != NULL);
56     if XNN_UNPREDICTABLE(i3 != zero) {
57       i3 = (const float*) ((uintptr_t) i3 + input_offset);
58     }
59     const float* i4 = input[4];
60     assert(i4 != NULL);
61     if XNN_UNPREDICTABLE(i4 != zero) {
62       i4 = (const float*) ((uintptr_t) i4 + input_offset);
63     }
64     const float* i5 = input[5];
65     assert(i5 != NULL);
66     if XNN_UNPREDICTABLE(i5 != zero) {
67       i5 = (const float*) ((uintptr_t) i5 + input_offset);
68     }
69     const float* i6 = input[6];
70     assert(i6 != NULL);
71     if XNN_UNPREDICTABLE(i6 != zero) {
72       i6 = (const float*) ((uintptr_t) i6 + input_offset);
73     }
74     const float* i7 = input[7];
75     assert(i7 != NULL);
76     if XNN_UNPREDICTABLE(i7 != zero) {
77       i7 = (const float*) ((uintptr_t) i7 + input_offset);
78     }
79     const float* i8 = input[8];
80     assert(i8 != NULL);
81     if XNN_UNPREDICTABLE(i8 != zero) {
82       i8 = (const float*) ((uintptr_t) i8 + input_offset);
83     }
84     const float* i9 = input[9];
85     assert(i9 != NULL);
86     if XNN_UNPREDICTABLE(i9 != zero) {
87       i9 = (const float*) ((uintptr_t) i9 + input_offset);
88     }
89     const float* i10 = input[10];
90     assert(i10 != NULL);
91     if XNN_UNPREDICTABLE(i10 != zero) {
92       i10 = (const float*) ((uintptr_t) i10 + input_offset);
93     }
94     const float* i11 = input[11];
95     assert(i11 != NULL);
96     if XNN_UNPREDICTABLE(i11 != zero) {
97       i11 = (const float*) ((uintptr_t) i11 + input_offset);
98     }
99     const float* i12 = input[12];
100     assert(i12 != NULL);
101     if XNN_UNPREDICTABLE(i12 != zero) {
102       i12 = (const float*) ((uintptr_t) i12 + input_offset);
103     }
104     const float* i13 = input[13];
105     assert(i13 != NULL);
106     if XNN_UNPREDICTABLE(i13 != zero) {
107       i13 = (const float*) ((uintptr_t) i13 + input_offset);
108     }
109     const float* i14 = input[14];
110     assert(i14 != NULL);
111     if XNN_UNPREDICTABLE(i14 != zero) {
112       i14 = (const float*) ((uintptr_t) i14 + input_offset);
113     }
114     const float* i15 = input[15];
115     assert(i15 != NULL);
116     if XNN_UNPREDICTABLE(i15 != zero) {
117       i15 = (const float*) ((uintptr_t) i15 + input_offset);
118     }
119     const float* i16 = input[16];
120     assert(i16 != NULL);
121     if XNN_UNPREDICTABLE(i16 != zero) {
122       i16 = (const float*) ((uintptr_t) i16 + input_offset);
123     }
124     const float* i17 = input[17];
125     assert(i17 != NULL);
126     if XNN_UNPREDICTABLE(i17 != zero) {
127       i17 = (const float*) ((uintptr_t) i17 + input_offset);
128     }
129     const float* i18 = input[18];
130     assert(i18 != NULL);
131     if XNN_UNPREDICTABLE(i18 != zero) {
132       i18 = (const float*) ((uintptr_t) i18 + input_offset);
133     }
134     const float* i19 = input[19];
135     assert(i19 != NULL);
136     if XNN_UNPREDICTABLE(i19 != zero) {
137       i19 = (const float*) ((uintptr_t) i19 + input_offset);
138     }
139     const float* i20 = input[20];
140     assert(i20 != NULL);
141     if XNN_UNPREDICTABLE(i20 != zero) {
142       i20 = (const float*) ((uintptr_t) i20 + input_offset);
143     }
144     const float* i21 = input[21];
145     assert(i21 != NULL);
146     if XNN_UNPREDICTABLE(i21 != zero) {
147       i21 = (const float*) ((uintptr_t) i21 + input_offset);
148     }
149     const float* i22 = input[22];
150     assert(i22 != NULL);
151     if XNN_UNPREDICTABLE(i22 != zero) {
152       i22 = (const float*) ((uintptr_t) i22 + input_offset);
153     }
154     const float* i23 = input[23];
155     assert(i23 != NULL);
156     if XNN_UNPREDICTABLE(i23 != zero) {
157       i23 = (const float*) ((uintptr_t) i23 + input_offset);
158     }
159     const float* i24 = input[24];
160     assert(i24 != NULL);
161     if XNN_UNPREDICTABLE(i24 != zero) {
162       i24 = (const float*) ((uintptr_t) i24 + input_offset);
163     }
164     input = (const float**) ((uintptr_t) input + input_stride);
165 
166     size_t c = channels;
167     const float* w = weights;
168     for (; c >= 16; c -= 16) {
169       __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
170 
171 
172       const __m512 vi0x0123456789ABCDEF = _mm512_loadu_ps(i0);
173       i0 += 16;
174 
175       const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 16);
176       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
177 
178       const __m512 vi1x0123456789ABCDEF = _mm512_loadu_ps(i1);
179       i1 += 16;
180 
181       const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 32);
182       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
183 
184       const __m512 vi2x0123456789ABCDEF = _mm512_loadu_ps(i2);
185       i2 += 16;
186 
187       const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 48);
188       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
189 
190       const __m512 vi3x0123456789ABCDEF = _mm512_loadu_ps(i3);
191       i3 += 16;
192 
193       const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 64);
194       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
195 
196       const __m512 vi4x0123456789ABCDEF = _mm512_loadu_ps(i4);
197       i4 += 16;
198 
199       const __m512 vk4x0123456789ABCDEF = _mm512_load_ps(w + 80);
200       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
201 
202       const __m512 vi5x0123456789ABCDEF = _mm512_loadu_ps(i5);
203       i5 += 16;
204 
205       const __m512 vk5x0123456789ABCDEF = _mm512_load_ps(w + 96);
206       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp0);
207 
208       const __m512 vi6x0123456789ABCDEF = _mm512_loadu_ps(i6);
209       i6 += 16;
210 
211       const __m512 vk6x0123456789ABCDEF = _mm512_load_ps(w + 112);
212       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
213 
214       const __m512 vi7x0123456789ABCDEF = _mm512_loadu_ps(i7);
215       i7 += 16;
216 
217       const __m512 vk7x0123456789ABCDEF = _mm512_load_ps(w + 128);
218       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp0);
219 
220       const __m512 vi8x0123456789ABCDEF = _mm512_loadu_ps(i8);
221       i8 += 16;
222 
223       const __m512 vk8x0123456789ABCDEF = _mm512_load_ps(w + 144);
224       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
225 
226       const __m512 vi9x0123456789ABCDEF = _mm512_loadu_ps(i9);
227       i9 += 16;
228 
229       const __m512 vk9x0123456789ABCDEF = _mm512_load_ps(w + 160);
230       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF, vacc0123456789ABCDEFp0);
231 
232       const __m512 vi10x0123456789ABCDEF = _mm512_loadu_ps(i10);
233       i10 += 16;
234 
235       const __m512 vk10x0123456789ABCDEF = _mm512_load_ps(w + 176);
236       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF, vacc0123456789ABCDEFp0);
237 
238       const __m512 vi11x0123456789ABCDEF = _mm512_loadu_ps(i11);
239       i11 += 16;
240 
241       const __m512 vk11x0123456789ABCDEF = _mm512_load_ps(w + 192);
242       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF, vacc0123456789ABCDEFp0);
243 
244       const __m512 vi12x0123456789ABCDEF = _mm512_loadu_ps(i12);
245       i12 += 16;
246 
247       const __m512 vk12x0123456789ABCDEF = _mm512_load_ps(w + 208);
248       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF, vacc0123456789ABCDEFp0);
249 
250       const __m512 vi13x0123456789ABCDEF = _mm512_loadu_ps(i13);
251       i13 += 16;
252 
253       const __m512 vk13x0123456789ABCDEF = _mm512_load_ps(w + 224);
254       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF, vacc0123456789ABCDEFp0);
255 
256       const __m512 vi14x0123456789ABCDEF = _mm512_loadu_ps(i14);
257       i14 += 16;
258 
259       const __m512 vk14x0123456789ABCDEF = _mm512_load_ps(w + 240);
260       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF, vacc0123456789ABCDEFp0);
261 
262       const __m512 vi15x0123456789ABCDEF = _mm512_loadu_ps(i15);
263       i15 += 16;
264 
265       const __m512 vk15x0123456789ABCDEF = _mm512_load_ps(w + 256);
266       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF, vacc0123456789ABCDEFp0);
267 
268       const __m512 vi16x0123456789ABCDEF = _mm512_loadu_ps(i16);
269       i16 += 16;
270 
271       const __m512 vk16x0123456789ABCDEF = _mm512_load_ps(w + 272);
272       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF, vacc0123456789ABCDEFp0);
273 
274       const __m512 vi17x0123456789ABCDEF = _mm512_loadu_ps(i17);
275       i17 += 16;
276 
277       const __m512 vk17x0123456789ABCDEF = _mm512_load_ps(w + 288);
278       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF, vacc0123456789ABCDEFp0);
279 
280       const __m512 vi18x0123456789ABCDEF = _mm512_loadu_ps(i18);
281       i18 += 16;
282 
283       const __m512 vk18x0123456789ABCDEF = _mm512_load_ps(w + 304);
284       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF, vacc0123456789ABCDEFp0);
285 
286       const __m512 vi19x0123456789ABCDEF = _mm512_loadu_ps(i19);
287       i19 += 16;
288 
289       const __m512 vk19x0123456789ABCDEF = _mm512_load_ps(w + 320);
290       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF, vacc0123456789ABCDEFp0);
291 
292       const __m512 vi20x0123456789ABCDEF = _mm512_loadu_ps(i20);
293       i20 += 16;
294 
295       const __m512 vk20x0123456789ABCDEF = _mm512_load_ps(w + 336);
296       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF, vacc0123456789ABCDEFp0);
297 
298       const __m512 vi21x0123456789ABCDEF = _mm512_loadu_ps(i21);
299       i21 += 16;
300 
301       const __m512 vk21x0123456789ABCDEF = _mm512_load_ps(w + 352);
302       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF, vacc0123456789ABCDEFp0);
303 
304       const __m512 vi22x0123456789ABCDEF = _mm512_loadu_ps(i22);
305       i22 += 16;
306 
307       const __m512 vk22x0123456789ABCDEF = _mm512_load_ps(w + 368);
308       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF, vacc0123456789ABCDEFp0);
309 
310       const __m512 vi23x0123456789ABCDEF = _mm512_loadu_ps(i23);
311       i23 += 16;
312 
313       const __m512 vk23x0123456789ABCDEF = _mm512_load_ps(w + 384);
314       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF, vacc0123456789ABCDEFp0);
315 
316       const __m512 vi24x0123456789ABCDEF = _mm512_loadu_ps(i24);
317       i24 += 16;
318 
319       const __m512 vk24x0123456789ABCDEF = _mm512_load_ps(w + 400);
320       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF, vacc0123456789ABCDEFp0);
321 
322       w += 416;
323 
324 
325       __m512 vacc0123456789ABCDEF = _mm512_max_ps(vacc0123456789ABCDEFp0, vmin);
326       vacc0123456789ABCDEF = _mm512_min_ps(vacc0123456789ABCDEF, vmax);
327 
328       _mm512_storeu_ps(output, vacc0123456789ABCDEF);
329       output += 16;
330     }
331     if XNN_UNLIKELY(c != 0) {
332       assert(c >= 1);
333       assert(c <= 16);
334       // Prepare mask for valid 32-bit elements (depends on c).
335       const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
336 
337       __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
338 
339       const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
340       const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 16);
341       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
342 
343       const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
344       const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
345       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
346 
347       const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
348       const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 48);
349       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
350 
351       const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
352       const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
353       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
354 
355       const __m512 vi4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i4);
356       const __m512 vk4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 80);
357       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
358 
359       const __m512 vi5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i5);
360       const __m512 vk5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96);
361       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp0);
362 
363       const __m512 vi6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i6);
364       const __m512 vk6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 112);
365       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
366 
367       const __m512 vi7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i7);
368       const __m512 vk7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128);
369       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp0);
370 
371       const __m512 vi8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i8);
372       const __m512 vk8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 144);
373       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
374 
375       const __m512 vi9x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i9);
376       const __m512 vk9x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 160);
377       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF, vacc0123456789ABCDEFp0);
378 
379       const __m512 vi10x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i10);
380       const __m512 vk10x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 176);
381       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF, vacc0123456789ABCDEFp0);
382 
383       const __m512 vi11x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i11);
384       const __m512 vk11x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 192);
385       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF, vacc0123456789ABCDEFp0);
386 
387       const __m512 vi12x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i12);
388       const __m512 vk12x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 208);
389       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF, vacc0123456789ABCDEFp0);
390 
391       const __m512 vi13x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i13);
392       const __m512 vk13x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 224);
393       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF, vacc0123456789ABCDEFp0);
394 
395       const __m512 vi14x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i14);
396       const __m512 vk14x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 240);
397       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF, vacc0123456789ABCDEFp0);
398 
399       const __m512 vi15x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i15);
400       const __m512 vk15x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 256);
401       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF, vacc0123456789ABCDEFp0);
402 
403       const __m512 vi16x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i16);
404       const __m512 vk16x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 272);
405       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF, vacc0123456789ABCDEFp0);
406 
407       const __m512 vi17x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i17);
408       const __m512 vk17x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 288);
409       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF, vacc0123456789ABCDEFp0);
410 
411       const __m512 vi18x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i18);
412       const __m512 vk18x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 304);
413       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF, vacc0123456789ABCDEFp0);
414 
415       const __m512 vi19x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i19);
416       const __m512 vk19x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 320);
417       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF, vacc0123456789ABCDEFp0);
418 
419       const __m512 vi20x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i20);
420       const __m512 vk20x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 336);
421       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF, vacc0123456789ABCDEFp0);
422 
423       const __m512 vi21x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i21);
424       const __m512 vk21x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 352);
425       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF, vacc0123456789ABCDEFp0);
426 
427       const __m512 vi22x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i22);
428       const __m512 vk22x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 368);
429       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF, vacc0123456789ABCDEFp0);
430 
431       const __m512 vi23x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i23);
432       const __m512 vk23x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 384);
433       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF, vacc0123456789ABCDEFp0);
434 
435       const __m512 vi24x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i24);
436       const __m512 vk24x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 400);
437       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF, vacc0123456789ABCDEFp0);
438 
439 
440       __m512 vacc0123456789ABCDEF = _mm512_max_ps(vacc0123456789ABCDEFp0, vmin);
441       vacc0123456789ABCDEF = _mm512_min_ps(vacc0123456789ABCDEF, vmax);
442 
443       _mm512_mask_storeu_ps(output, vmask, vacc0123456789ABCDEF);
444       output += c;
445     }
446 
447     output = (float*) ((uintptr_t) output + output_increment);
448   } while (--output_width != 0);
449 }
450 
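// A minimal scalar sketch (not part of XNNPACK) of what one iteration of the outer
// do/while loop in the up16x25 kernel above computes for a single output pixel. It
// assumes the packed weight layout that the vector loads imply: for every block of
// 16 channels, 16 bias values followed by 25 groups of 16 per-tap weights (416 floats
// per block, matching `w += 416`); the 3-, 4- and 9-tap kernels below use the same
// layout with the corresponding kernel_size. The function name and the `kernel_size`
// parameter are illustrative only, and `input_offset` is assumed to have been applied.
static void dwconv_minmax_scalar_sketch(
    size_t channels,
    size_t kernel_size,      // 25 for the kernel above
    const float** input,     // kernel_size pointers, one per tap
    const float* weights,
    float* output,
    float output_min,
    float output_max)
{
  for (size_t c = 0; c < channels; c++) {
    const size_t block = c / 16;  // 16-channel block index
    const size_t lane = c % 16;   // channel within the block
    const float* w = weights + block * (16 + 16 * kernel_size);
    float acc = w[lane];          // bias
    for (size_t k = 0; k < kernel_size; k++) {
      acc += input[k][c] * w[16 + 16 * k + lane];
    }
    acc = acc < output_min ? output_min : acc;  // mirrors _mm512_max_ps(vacc, vmin)
    acc = acc > output_max ? output_max : acc;  // mirrors _mm512_min_ps(vacc, vmax)
    output[c] = acc;
  }
}
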
451 void xnn_f32_dwconv_minmax_ukernel_up16x3__avx512f(
452     size_t channels,
453     size_t output_width,
454     const float** input,
455     const float* weights,
456     float* output,
457     size_t input_stride,
458     size_t output_increment,
459     size_t input_offset,
460     const float* zero,
461     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
462 {
463   assert(channels != 0);
464   assert(output_width != 0);
465 
466   const __m512 vmax = _mm512_set1_ps(params->scalar.max);
467   const __m512 vmin = _mm512_set1_ps(params->scalar.min);
468   do {
469     const float* i0 = input[0];
470     assert(i0 != NULL);
471     if XNN_UNPREDICTABLE(i0 != zero) {
472       i0 = (const float*) ((uintptr_t) i0 + input_offset);
473     }
474     const float* i1 = input[1];
475     assert(i1 != NULL);
476     if XNN_UNPREDICTABLE(i1 != zero) {
477       i1 = (const float*) ((uintptr_t) i1 + input_offset);
478     }
479     const float* i2 = input[2];
480     assert(i2 != NULL);
481     if XNN_UNPREDICTABLE(i2 != zero) {
482       i2 = (const float*) ((uintptr_t) i2 + input_offset);
483     }
484     input = (const float**) ((uintptr_t) input + input_stride);
485 
486     size_t c = channels;
487     const float* w = weights;
488     for (; c >= 16; c -= 16) {
489       __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
490 
491 
492       const __m512 vi0x0123456789ABCDEF = _mm512_loadu_ps(i0);
493       i0 += 16;
494 
495       const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 16);
496       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
497 
498       const __m512 vi1x0123456789ABCDEF = _mm512_loadu_ps(i1);
499       i1 += 16;
500 
501       const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 32);
502       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
503 
504       const __m512 vi2x0123456789ABCDEF = _mm512_loadu_ps(i2);
505       i2 += 16;
506 
507       const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 48);
508       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
509 
510       w += 64;
511 
512 
513       __m512 vacc0123456789ABCDEF = _mm512_max_ps(vacc0123456789ABCDEFp0, vmin);
514       vacc0123456789ABCDEF = _mm512_min_ps(vacc0123456789ABCDEF, vmax);
515 
516       _mm512_storeu_ps(output, vacc0123456789ABCDEF);
517       output += 16;
518     }
519     if XNN_UNLIKELY(c != 0) {
520       assert(c >= 1);
521       assert(c <= 16);
522       // Prepare mask for valid 32-bit elements (depends on c).
523       const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
524 
525       __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
526 
527       const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
528       const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 16);
529       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
530 
531       const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
532       const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
533       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
534 
535       const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
536       const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 48);
537       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
538 
539 
540       __m512 vacc0123456789ABCDEF = _mm512_max_ps(vacc0123456789ABCDEFp0, vmin);
541       vacc0123456789ABCDEF = _mm512_min_ps(vacc0123456789ABCDEF, vmax);
542 
543       _mm512_mask_storeu_ps(output, vmask, vacc0123456789ABCDEF);
544       output += c;
545     }
546 
547     output = (float*) ((uintptr_t) output + output_increment);
548   } while (--output_width != 0);
549 }
550 
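// A small sketch (not part of XNNPACK) of the remainder-channel masking used in the
// tail paths above: for 1 <= c <= 16 leftover channels, the low c bits of the mask are
// set, masked loads zero the unused lanes, and the masked store writes exactly c floats.
// For example, c == 5 produces the mask 0x001F. The function name is illustrative only.
static void masked_copy_sketch(float* output, const float* input, size_t c) {
  assert(c >= 1);
  assert(c <= 16);
  const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((UINT32_C(1) << c) - UINT32_C(1)));
  const __m512 v = _mm512_maskz_loadu_ps(vmask, input);  // lanes c..15 read as zero
  _mm512_mask_storeu_ps(output, vmask, v);               // only the low c floats are written
}
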
551 void xnn_f32_dwconv_minmax_ukernel_up16x4__avx512f(
552     size_t channels,
553     size_t output_width,
554     const float** input,
555     const float* weights,
556     float* output,
557     size_t input_stride,
558     size_t output_increment,
559     size_t input_offset,
560     const float* zero,
561     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
562 {
563   assert(channels != 0);
564   assert(output_width != 0);
565 
566   const __m512 vmax = _mm512_set1_ps(params->scalar.max);
567   const __m512 vmin = _mm512_set1_ps(params->scalar.min);
568   do {
569     const float* i0 = input[0];
570     assert(i0 != NULL);
571     if XNN_UNPREDICTABLE(i0 != zero) {
572       i0 = (const float*) ((uintptr_t) i0 + input_offset);
573     }
574     const float* i1 = input[1];
575     assert(i1 != NULL);
576     if XNN_UNPREDICTABLE(i1 != zero) {
577       i1 = (const float*) ((uintptr_t) i1 + input_offset);
578     }
579     const float* i2 = input[2];
580     assert(i2 != NULL);
581     if XNN_UNPREDICTABLE(i2 != zero) {
582       i2 = (const float*) ((uintptr_t) i2 + input_offset);
583     }
584     const float* i3 = input[3];
585     assert(i3 != NULL);
586     if XNN_UNPREDICTABLE(i3 != zero) {
587       i3 = (const float*) ((uintptr_t) i3 + input_offset);
588     }
589     input = (const float**) ((uintptr_t) input + input_stride);
590 
591     size_t c = channels;
592     const float* w = weights;
593     for (; c >= 16; c -= 16) {
594       __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
595 
596 
597       const __m512 vi0x0123456789ABCDEF = _mm512_loadu_ps(i0);
598       i0 += 16;
599 
600       const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 16);
601       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
602 
603       const __m512 vi1x0123456789ABCDEF = _mm512_loadu_ps(i1);
604       i1 += 16;
605 
606       const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 32);
607       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
608 
609       const __m512 vi2x0123456789ABCDEF = _mm512_loadu_ps(i2);
610       i2 += 16;
611 
612       const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 48);
613       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
614 
615       const __m512 vi3x0123456789ABCDEF = _mm512_loadu_ps(i3);
616       i3 += 16;
617 
618       const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 64);
619       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
620 
621       w += 80;
622 
623 
624       __m512 vacc0123456789ABCDEF = _mm512_max_ps(vacc0123456789ABCDEFp0, vmin);
625       vacc0123456789ABCDEF = _mm512_min_ps(vacc0123456789ABCDEF, vmax);
626 
627       _mm512_storeu_ps(output, vacc0123456789ABCDEF);
628       output += 16;
629     }
630     if XNN_UNLIKELY(c != 0) {
631       assert(c >= 1);
632       assert(c <= 16);
633       // Prepare mask for valid 32-bit elements (depends on c).
634       const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
635 
636       __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
637 
638       const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
639       const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 16);
640       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
641 
642       const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
643       const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
644       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
645 
646       const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
647       const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 48);
648       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
649 
650       const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
651       const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
652       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
653 
654 
655       __m512 vacc0123456789ABCDEF = _mm512_max_ps(vacc0123456789ABCDEFp0, vmin);
656       vacc0123456789ABCDEF = _mm512_min_ps(vacc0123456789ABCDEF, vmax);
657 
658       _mm512_mask_storeu_ps(output, vmask, vacc0123456789ABCDEF);
659       output += c;
660     }
661 
662     output = (float*) ((uintptr_t) output + output_increment);
663   } while (--output_width != 0);
664 }
665 
666 void xnn_f32_dwconv_minmax_ukernel_up16x9__avx512f(
667     size_t channels,
668     size_t output_width,
669     const float** input,
670     const float* weights,
671     float* output,
672     size_t input_stride,
673     size_t output_increment,
674     size_t input_offset,
675     const float* zero,
676     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
677 {
678   assert(channels != 0);
679   assert(output_width != 0);
680 
681   const __m512 vmax = _mm512_set1_ps(params->scalar.max);
682   const __m512 vmin = _mm512_set1_ps(params->scalar.min);
683   do {
684     const float* i0 = input[0];
685     assert(i0 != NULL);
686     if XNN_UNPREDICTABLE(i0 != zero) {
687       i0 = (const float*) ((uintptr_t) i0 + input_offset);
688     }
689     const float* i1 = input[1];
690     assert(i1 != NULL);
691     if XNN_UNPREDICTABLE(i1 != zero) {
692       i1 = (const float*) ((uintptr_t) i1 + input_offset);
693     }
694     const float* i2 = input[2];
695     assert(i2 != NULL);
696     if XNN_UNPREDICTABLE(i2 != zero) {
697       i2 = (const float*) ((uintptr_t) i2 + input_offset);
698     }
699     const float* i3 = input[3];
700     assert(i3 != NULL);
701     if XNN_UNPREDICTABLE(i3 != zero) {
702       i3 = (const float*) ((uintptr_t) i3 + input_offset);
703     }
704     const float* i4 = input[4];
705     assert(i4 != NULL);
706     if XNN_UNPREDICTABLE(i4 != zero) {
707       i4 = (const float*) ((uintptr_t) i4 + input_offset);
708     }
709     const float* i5 = input[5];
710     assert(i5 != NULL);
711     if XNN_UNPREDICTABLE(i5 != zero) {
712       i5 = (const float*) ((uintptr_t) i5 + input_offset);
713     }
714     const float* i6 = input[6];
715     assert(i6 != NULL);
716     if XNN_UNPREDICTABLE(i6 != zero) {
717       i6 = (const float*) ((uintptr_t) i6 + input_offset);
718     }
719     const float* i7 = input[7];
720     assert(i7 != NULL);
721     if XNN_UNPREDICTABLE(i7 != zero) {
722       i7 = (const float*) ((uintptr_t) i7 + input_offset);
723     }
724     const float* i8 = input[8];
725     assert(i8 != NULL);
726     if XNN_UNPREDICTABLE(i8 != zero) {
727       i8 = (const float*) ((uintptr_t) i8 + input_offset);
728     }
729     input = (const float**) ((uintptr_t) input + input_stride);
730 
731     size_t c = channels;
732     const float* w = weights;
733     for (; c >= 16; c -= 16) {
734       __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
735 
736 
737       const __m512 vi0x0123456789ABCDEF = _mm512_loadu_ps(i0);
738       i0 += 16;
739 
740       const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 16);
741       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
742 
743       const __m512 vi1x0123456789ABCDEF = _mm512_loadu_ps(i1);
744       i1 += 16;
745 
746       const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 32);
747       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
748 
749       const __m512 vi2x0123456789ABCDEF = _mm512_loadu_ps(i2);
750       i2 += 16;
751 
752       const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 48);
753       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
754 
755       const __m512 vi3x0123456789ABCDEF = _mm512_loadu_ps(i3);
756       i3 += 16;
757 
758       const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 64);
759       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
760 
761       const __m512 vi4x0123456789ABCDEF = _mm512_loadu_ps(i4);
762       i4 += 16;
763 
764       const __m512 vk4x0123456789ABCDEF = _mm512_load_ps(w + 80);
765       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
766 
767       const __m512 vi5x0123456789ABCDEF = _mm512_loadu_ps(i5);
768       i5 += 16;
769 
770       const __m512 vk5x0123456789ABCDEF = _mm512_load_ps(w + 96);
771       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp0);
772 
773       const __m512 vi6x0123456789ABCDEF = _mm512_loadu_ps(i6);
774       i6 += 16;
775 
776       const __m512 vk6x0123456789ABCDEF = _mm512_load_ps(w + 112);
777       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
778 
779       const __m512 vi7x0123456789ABCDEF = _mm512_loadu_ps(i7);
780       i7 += 16;
781 
782       const __m512 vk7x0123456789ABCDEF = _mm512_load_ps(w + 128);
783       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp0);
784 
785       const __m512 vi8x0123456789ABCDEF = _mm512_loadu_ps(i8);
786       i8 += 16;
787 
788       const __m512 vk8x0123456789ABCDEF = _mm512_load_ps(w + 144);
789       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
790 
791       w += 160;
792 
793 
794       __m512 vacc0123456789ABCDEF = _mm512_max_ps(vacc0123456789ABCDEFp0, vmin);
795       vacc0123456789ABCDEF = _mm512_min_ps(vacc0123456789ABCDEF, vmax);
796 
797       _mm512_storeu_ps(output, vacc0123456789ABCDEF);
798       output += 16;
799     }
800     if XNN_UNLIKELY(c != 0) {
801       assert(c >= 1);
802       assert(c <= 16);
803       // Prepare mask for valid 32-bit elements (depends on c).
804       const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
805 
806       __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
807 
808       const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
809       const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 16);
810       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
811 
812       const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
813       const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
814       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
815 
816       const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
817       const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 48);
818       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
819 
820       const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
821       const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
822       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
823 
824       const __m512 vi4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i4);
825       const __m512 vk4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 80);
826       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
827 
828       const __m512 vi5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i5);
829       const __m512 vk5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96);
830       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp0);
831 
832       const __m512 vi6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i6);
833       const __m512 vk6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 112);
834       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
835 
836       const __m512 vi7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i7);
837       const __m512 vk7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128);
838       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp0);
839 
840       const __m512 vi8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i8);
841       const __m512 vk8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 144);
842       vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
843 
844 
845       __m512 vacc0123456789ABCDEF = _mm512_max_ps(vacc0123456789ABCDEFp0, vmin);
846       vacc0123456789ABCDEF = _mm512_min_ps(vacc0123456789ABCDEF, vmax);
847 
848       _mm512_mask_storeu_ps(output, vmask, vacc0123456789ABCDEF);
849       output += c;
850     }
851 
852     output = (float*) ((uintptr_t) output + output_increment);
853   } while (--output_width != 0);
854 }
855 
856 void xnn_f32_gemm_minmax_ukernel_1x16__avx512f_broadcast(
857     size_t mr,
858     size_t nc,
859     size_t kc,
860     const float*restrict a,
861     size_t a_stride,
862     const float*restrict w,
863     float*restrict c,
864     size_t cm_stride,
865     size_t cn_stride,
866     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
867 {
868   assert(mr != 0);
869   assert(mr <= 1);
870   assert(nc != 0);
871   assert(kc != 0);
872   assert(kc % sizeof(float) == 0);
873   assert(a != NULL);
874   assert(w != NULL);
875   assert(c != NULL);
876 
877   const float* a0 = a;
878   float* c0 = c;
879 
880   do {
881     __m512 vacc0x0123456789ABCDEF = _mm512_load_ps(w);
882     w += 16;
883 
884     size_t k = kc;
885     do {
886       const __m512 vb0123456789ABCDEF = _mm512_load_ps(w);
887       w += 16;
888 
889       const __m512 va0 = _mm512_set1_ps(*a0);
890       vacc0x0123456789ABCDEF = _mm512_fmadd_ps(va0, vb0123456789ABCDEF, vacc0x0123456789ABCDEF);
891 
892       a0 += 1;
893 
894       k -= sizeof(float);
895     } while (k != 0);
896 
897     const __m512 vmin = _mm512_set1_ps(params->scalar.min);
898     vacc0x0123456789ABCDEF = _mm512_max_ps(vacc0x0123456789ABCDEF, vmin);
899 
900     const __m512 vmax = _mm512_set1_ps(params->scalar.max);
901     vacc0x0123456789ABCDEF = _mm512_min_ps(vacc0x0123456789ABCDEF, vmax);
902 
903     if XNN_LIKELY(nc >= 16) {
904       _mm512_storeu_ps(c0, vacc0x0123456789ABCDEF);
905       c0 = (float*) ((uintptr_t) c0 + cn_stride);
906 
907       a0 = (const float*) ((uintptr_t) a0 - kc);
908 
909       nc -= 16;
910     } else {
911       if (nc & 15) {
912         // Prepare mask for valid 32-bit elements (depends on nc).
913         const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << nc) - UINT32_C(1)));
914 
915         _mm512_mask_storeu_ps(c0, vmask, vacc0x0123456789ABCDEF);
916       }
917 
918       nc = 0;
919     }
920   } while (nc != 0);
921 }
922 
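// A minimal scalar sketch (not part of XNNPACK) of one 16-column block computed by the
// 1x16 GEMM microkernel above, assuming the packed weight layout the vector loads imply:
// 16 bias values followed by kc/sizeof(float) groups of 16 elements of B. The function
// name and the `kc_floats` parameter are illustrative only.
static void gemm_minmax_scalar_sketch(
    size_t kc_floats,    // kc / sizeof(float) in the kernel above
    const float* a,      // one row of A, kc_floats elements
    const float* w,      // packed weights: [16 bias][kc_floats groups of 16 from B]
    float* c,            // one 1x16 block of the output row
    float output_min,
    float output_max)
{
  for (size_t n = 0; n < 16; n++) {
    float acc = w[n];  // bias
    for (size_t k = 0; k < kc_floats; k++) {
      acc += a[k] * w[16 + 16 * k + n];  // broadcast a[k] times lane n of the B block
    }
    acc = acc < output_min ? output_min : acc;
    acc = acc > output_max ? output_max : acc;
    c[n] = acc;
  }
}
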
923 void xnn_f32_gemm_minmax_ukernel_7x16__avx512f_broadcast(
924     size_t mr,
925     size_t nc,
926     size_t kc,
927     const float*restrict a,
928     size_t a_stride,
929     const float*restrict w,
930     float*restrict c,
931     size_t cm_stride,
932     size_t cn_stride,
933     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
934 {
935   assert(mr != 0);
936   assert(mr <= 7);
937   assert(nc != 0);
938   assert(kc != 0);
939   assert(kc % sizeof(float) == 0);
940   assert(a != NULL);
941   assert(w != NULL);
942   assert(c != NULL);
943 
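  // Row pointer setup: when mr < 7, the pointers for rows beyond mr are aliased to the
  // previous row (e.g. a1 = a0 and c1 = c0 when mr < 2), so the main loop stays
  // branch-free; the redundant rows recompute and re-store the same values as the row
  // they alias.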
944   const float* a0 = a;
945   float* c0 = c;
946   const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
947   float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
948   if XNN_UNPREDICTABLE(mr < 2) {
949     a1 = a0;
950     c1 = c0;
951   }
952   const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
953   float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
954   if XNN_UNPREDICTABLE(mr <= 2) {
955     a2 = a1;
956     c2 = c1;
957   }
958   const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
959   float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
960   if XNN_UNPREDICTABLE(mr < 4) {
961     a3 = a2;
962     c3 = c2;
963   }
964   const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
965   float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
966   if XNN_UNPREDICTABLE(mr <= 4) {
967     a4 = a3;
968     c4 = c3;
969   }
970   const float* a5 = (const float*) ((uintptr_t) a4 + a_stride);
971   float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
972   if XNN_UNPREDICTABLE(mr < 6) {
973     a5 = a4;
974     c5 = c4;
975   }
976   const float* a6 = (const float*) ((uintptr_t) a5 + a_stride);
977   float* c6 = (float*) ((uintptr_t) c5 + cm_stride);
978   if XNN_UNPREDICTABLE(mr <= 6) {
979     a6 = a5;
980     c6 = c5;
981   }
982 
983   do {
984     __m512 vacc0x0123456789ABCDEF = _mm512_load_ps(w);
985     __m512 vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF;
986     __m512 vacc2x0123456789ABCDEF = vacc0x0123456789ABCDEF;
987     __m512 vacc3x0123456789ABCDEF = vacc0x0123456789ABCDEF;
988     __m512 vacc4x0123456789ABCDEF = vacc0x0123456789ABCDEF;
989     __m512 vacc5x0123456789ABCDEF = vacc0x0123456789ABCDEF;
990     __m512 vacc6x0123456789ABCDEF = vacc0x0123456789ABCDEF;
991     w += 16;
992 
993     size_t k = kc;
994     do {
995       const __m512 vb0123456789ABCDEF = _mm512_load_ps(w);
996       w += 16;
997 
998       const __m512 va0 = _mm512_set1_ps(*a0);
999       vacc0x0123456789ABCDEF = _mm512_fmadd_ps(va0, vb0123456789ABCDEF, vacc0x0123456789ABCDEF);
1000       const __m512 va1 = _mm512_set1_ps(*a1);
1001       vacc1x0123456789ABCDEF = _mm512_fmadd_ps(va1, vb0123456789ABCDEF, vacc1x0123456789ABCDEF);
1002       const __m512 va2 = _mm512_set1_ps(*a2);
1003       vacc2x0123456789ABCDEF = _mm512_fmadd_ps(va2, vb0123456789ABCDEF, vacc2x0123456789ABCDEF);
1004       const __m512 va3 = _mm512_set1_ps(*a3);
1005       vacc3x0123456789ABCDEF = _mm512_fmadd_ps(va3, vb0123456789ABCDEF, vacc3x0123456789ABCDEF);
1006       const __m512 va4 = _mm512_set1_ps(*a4);
1007       vacc4x0123456789ABCDEF = _mm512_fmadd_ps(va4, vb0123456789ABCDEF, vacc4x0123456789ABCDEF);
1008       const __m512 va5 = _mm512_set1_ps(*a5);
1009       vacc5x0123456789ABCDEF = _mm512_fmadd_ps(va5, vb0123456789ABCDEF, vacc5x0123456789ABCDEF);
1010       const __m512 va6 = _mm512_set1_ps(*a6);
1011       vacc6x0123456789ABCDEF = _mm512_fmadd_ps(va6, vb0123456789ABCDEF, vacc6x0123456789ABCDEF);
1012 
1013       a0 += 1;
1014       a1 += 1;
1015       a2 += 1;
1016       a3 += 1;
1017       a4 += 1;
1018       a5 += 1;
1019       a6 += 1;
1020 
1021       k -= sizeof(float);
1022     } while (k != 0);
1023 
1024     const __m512 vmin = _mm512_set1_ps(params->scalar.min);
1025     vacc0x0123456789ABCDEF = _mm512_max_ps(vacc0x0123456789ABCDEF, vmin);
1026     vacc1x0123456789ABCDEF = _mm512_max_ps(vacc1x0123456789ABCDEF, vmin);
1027     vacc2x0123456789ABCDEF = _mm512_max_ps(vacc2x0123456789ABCDEF, vmin);
1028     vacc3x0123456789ABCDEF = _mm512_max_ps(vacc3x0123456789ABCDEF, vmin);
1029     vacc4x0123456789ABCDEF = _mm512_max_ps(vacc4x0123456789ABCDEF, vmin);
1030     vacc5x0123456789ABCDEF = _mm512_max_ps(vacc5x0123456789ABCDEF, vmin);
1031     vacc6x0123456789ABCDEF = _mm512_max_ps(vacc6x0123456789ABCDEF, vmin);
1032 
1033     const __m512 vmax = _mm512_set1_ps(params->scalar.max);
1034     vacc0x0123456789ABCDEF = _mm512_min_ps(vacc0x0123456789ABCDEF, vmax);
1035     vacc1x0123456789ABCDEF = _mm512_min_ps(vacc1x0123456789ABCDEF, vmax);
1036     vacc2x0123456789ABCDEF = _mm512_min_ps(vacc2x0123456789ABCDEF, vmax);
1037     vacc3x0123456789ABCDEF = _mm512_min_ps(vacc3x0123456789ABCDEF, vmax);
1038     vacc4x0123456789ABCDEF = _mm512_min_ps(vacc4x0123456789ABCDEF, vmax);
1039     vacc5x0123456789ABCDEF = _mm512_min_ps(vacc5x0123456789ABCDEF, vmax);
1040     vacc6x0123456789ABCDEF = _mm512_min_ps(vacc6x0123456789ABCDEF, vmax);
1041 
1042     if XNN_LIKELY(nc >= 16) {
1043       _mm512_storeu_ps(c6, vacc6x0123456789ABCDEF);
1044       c6 = (float*) ((uintptr_t) c6 + cn_stride);
1045       _mm512_storeu_ps(c5, vacc5x0123456789ABCDEF);
1046       c5 = (float*) ((uintptr_t) c5 + cn_stride);
1047       _mm512_storeu_ps(c4, vacc4x0123456789ABCDEF);
1048       c4 = (float*) ((uintptr_t) c4 + cn_stride);
1049       _mm512_storeu_ps(c3, vacc3x0123456789ABCDEF);
1050       c3 = (float*) ((uintptr_t) c3 + cn_stride);
1051       _mm512_storeu_ps(c2, vacc2x0123456789ABCDEF);
1052       c2 = (float*) ((uintptr_t) c2 + cn_stride);
1053       _mm512_storeu_ps(c1, vacc1x0123456789ABCDEF);
1054       c1 = (float*) ((uintptr_t) c1 + cn_stride);
1055       _mm512_storeu_ps(c0, vacc0x0123456789ABCDEF);
1056       c0 = (float*) ((uintptr_t) c0 + cn_stride);
1057 
1058       a6 = (const float*) ((uintptr_t) a6 - kc);
1059       a5 = (const float*) ((uintptr_t) a5 - kc);
1060       a4 = (const float*) ((uintptr_t) a4 - kc);
1061       a3 = (const float*) ((uintptr_t) a3 - kc);
1062       a2 = (const float*) ((uintptr_t) a2 - kc);
1063       a1 = (const float*) ((uintptr_t) a1 - kc);
1064       a0 = (const float*) ((uintptr_t) a0 - kc);
1065 
1066       nc -= 16;
1067     } else {
1068       if (nc & 15) {
1069         // Prepare mask for valid 32-bit elements (depends on nc).
1070         const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << nc) - UINT32_C(1)));
1071 
1072         _mm512_mask_storeu_ps(c6, vmask, vacc6x0123456789ABCDEF);
1073         _mm512_mask_storeu_ps(c5, vmask, vacc5x0123456789ABCDEF);
1074         _mm512_mask_storeu_ps(c4, vmask, vacc4x0123456789ABCDEF);
1075         _mm512_mask_storeu_ps(c3, vmask, vacc3x0123456789ABCDEF);
1076         _mm512_mask_storeu_ps(c2, vmask, vacc2x0123456789ABCDEF);
1077         _mm512_mask_storeu_ps(c1, vmask, vacc1x0123456789ABCDEF);
1078         _mm512_mask_storeu_ps(c0, vmask, vacc0x0123456789ABCDEF);
1079       }
1080 
1081       nc = 0;
1082     }
1083   } while (nc != 0);
1084 }
1085 
1086 void xnn_f32_igemm_minmax_ukernel_1x16__avx512f_broadcast(
1087     size_t mr,
1088     size_t nc,
1089     size_t kc,
1090     size_t ks,
1091     const float**restrict a,
1092     const float*restrict w,
1093     float*restrict c,
1094     size_t cm_stride,
1095     size_t cn_stride,
1096     size_t a_offset,
1097     const float* zero,
1098     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
1099 {
1100   assert(mr != 0);
1101   assert(mr <= 1);
1102   assert(nc != 0);
1103   assert(kc != 0);
1104   assert(kc % sizeof(float) == 0);
1105   assert(ks != 0);
1106   assert(ks % (1 * sizeof(void*)) == 0);
1107   assert(a_offset % sizeof(float) == 0);
1108   assert(a != NULL);
1109   assert(w != NULL);
1110   assert(c != NULL);
1111 
1112   float* c0 = c;
1113 
1114   do {
1115     __m512 vacc0x0123456789ABCDEF = _mm512_load_ps(w);
1116     w += 16;
1117 
1118     size_t p = ks;
1119     do {
1120       const float* restrict a0 = a[0];
1121       assert(a0 != NULL);
1122       if XNN_UNPREDICTABLE(a0 != zero) {
1123         a0 = (const float*) ((uintptr_t) a0 + a_offset);
1124       }
1125       a += 1;
1126 
1127       size_t k = kc;
1128       do {
1129         const __m512 vb0123456789ABCDEF = _mm512_load_ps(w);
1130         w += 16;
1131 
1132         const __m512 va0 = _mm512_set1_ps(*a0);
1133         vacc0x0123456789ABCDEF = _mm512_fmadd_ps(va0, vb0123456789ABCDEF, vacc0x0123456789ABCDEF);
1134 
1135         a0 += 1;
1136 
1137         k -= sizeof(float);
1138       } while (k != 0);
1139       p -= 1 * sizeof(void*);
1140     } while (p != 0);
1141 
1142     const __m512 vmin = _mm512_set1_ps(params->scalar.min);
1143     vacc0x0123456789ABCDEF = _mm512_max_ps(vacc0x0123456789ABCDEF, vmin);
1144 
1145     const __m512 vmax = _mm512_set1_ps(params->scalar.max);
1146     vacc0x0123456789ABCDEF = _mm512_min_ps(vacc0x0123456789ABCDEF, vmax);
1147 
1148     if XNN_LIKELY(nc >= 16) {
1149       _mm512_storeu_ps(c0, vacc0x0123456789ABCDEF);
1150       c0 = (float*) ((uintptr_t) c0 + cn_stride);
1151 
1152       a = (const float**restrict) ((uintptr_t) a - ks);
1153       nc -= 16;
1154     } else {
1155       if (nc & 15) {
1156         // Prepare mask for valid 32-bit elements (depends on nc).
1157         const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << nc) - UINT32_C(1)));
1158 
1159         _mm512_mask_storeu_ps(c0, vmask, vacc0x0123456789ABCDEF);
1160       }
1161 
1162       nc = 0;
1163     }
1164   } while (nc != 0);
1165 }
1166 
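// A minimal scalar sketch (not part of XNNPACK) of one 1x16 output block computed by the
// IGEMM microkernel above. `a` is the indirection buffer: ks pointers, each of which is
// either the shared `zero` buffer (used as-is) or a pointer that still needs `a_offset`
// added. The function name and the `kc_floats`/`ks_pointers` parameters are illustrative.
static void igemm_minmax_scalar_sketch(
    size_t kc_floats,      // kc / sizeof(float) in the kernel above
    size_t ks_pointers,    // ks / sizeof(void*) in the kernel above
    const float** a,
    size_t a_offset,
    const float* zero,
    const float* w,        // packed weights: [16 bias][ks_pointers * kc_floats groups of 16]
    float* c,
    float output_min,
    float output_max)
{
  float acc[16];
  for (size_t n = 0; n < 16; n++) {
    acc[n] = w[n];  // bias
  }
  w += 16;
  for (size_t p = 0; p < ks_pointers; p++) {
    const float* a0 = a[p];
    if (a0 != zero) {
      a0 = (const float*) ((uintptr_t) a0 + a_offset);
    }
    for (size_t k = 0; k < kc_floats; k++) {
      for (size_t n = 0; n < 16; n++) {
        acc[n] += a0[k] * w[n];
      }
      w += 16;
    }
  }
  for (size_t n = 0; n < 16; n++) {
    float v = acc[n];
    v = v < output_min ? output_min : v;
    v = v > output_max ? output_max : v;
    c[n] = v;
  }
}
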
1167 void xnn_f32_igemm_minmax_ukernel_7x16__avx512f_broadcast(
1168     size_t mr,
1169     size_t nc,
1170     size_t kc,
1171     size_t ks,
1172     const float**restrict a,
1173     const float*restrict w,
1174     float*restrict c,
1175     size_t cm_stride,
1176     size_t cn_stride,
1177     size_t a_offset,
1178     const float* zero,
1179     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
1180 {
1181   assert(mr != 0);
1182   assert(mr <= 7);
1183   assert(nc != 0);
1184   assert(kc != 0);
1185   assert(kc % sizeof(float) == 0);
1186   assert(ks != 0);
1187   assert(ks % (7 * sizeof(void*)) == 0);
1188   assert(a_offset % sizeof(float) == 0);
1189   assert(a != NULL);
1190   assert(w != NULL);
1191   assert(c != NULL);
1192 
1193   float* c0 = c;
1194   float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
1195   if XNN_UNPREDICTABLE(mr < 2) {
1196     c1 = c0;
1197   }
1198   float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
1199   if XNN_UNPREDICTABLE(mr <= 2) {
1200     c2 = c1;
1201   }
1202   float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
1203   if XNN_UNPREDICTABLE(mr < 4) {
1204     c3 = c2;
1205   }
1206   float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
1207   if XNN_UNPREDICTABLE(mr <= 4) {
1208     c4 = c3;
1209   }
1210   float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
1211   if XNN_UNPREDICTABLE(mr < 6) {
1212     c5 = c4;
1213   }
1214   float* c6 = (float*) ((uintptr_t) c5 + cm_stride);
1215   if XNN_UNPREDICTABLE(mr <= 6) {
1216     c6 = c5;
1217   }
1218 
1219   do {
1220     __m512 vacc0x0123456789ABCDEF = _mm512_load_ps(w);
1221     __m512 vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF;
1222     __m512 vacc2x0123456789ABCDEF = vacc0x0123456789ABCDEF;
1223     __m512 vacc3x0123456789ABCDEF = vacc0x0123456789ABCDEF;
1224     __m512 vacc4x0123456789ABCDEF = vacc0x0123456789ABCDEF;
1225     __m512 vacc5x0123456789ABCDEF = vacc0x0123456789ABCDEF;
1226     __m512 vacc6x0123456789ABCDEF = vacc0x0123456789ABCDEF;
1227     w += 16;
1228 
1229     size_t p = ks;
1230     do {
1231       const float* restrict a0 = a[0];
1232       assert(a0 != NULL);
1233       if XNN_UNPREDICTABLE(a0 != zero) {
1234         a0 = (const float*) ((uintptr_t) a0 + a_offset);
1235       }
1236       const float* restrict a1 = a[1];
1237       assert(a1 != NULL);
1238       if XNN_UNPREDICTABLE(a1 != zero) {
1239         a1 = (const float*) ((uintptr_t) a1 + a_offset);
1240       }
1241       const float* restrict a2 = a[2];
1242       assert(a2 != NULL);
1243       if XNN_UNPREDICTABLE(a2 != zero) {
1244         a2 = (const float*) ((uintptr_t) a2 + a_offset);
1245       }
1246       const float* restrict a3 = a[3];
1247       assert(a3 != NULL);
1248       if XNN_UNPREDICTABLE(a3 != zero) {
1249         a3 = (const float*) ((uintptr_t) a3 + a_offset);
1250       }
1251       const float* restrict a4 = a[4];
1252       assert(a4 != NULL);
1253       if XNN_UNPREDICTABLE(a4 != zero) {
1254         a4 = (const float*) ((uintptr_t) a4 + a_offset);
1255       }
1256       const float* restrict a5 = a[5];
1257       assert(a5 != NULL);
1258       if XNN_UNPREDICTABLE(a5 != zero) {
1259         a5 = (const float*) ((uintptr_t) a5 + a_offset);
1260       }
1261       const float* restrict a6 = a[6];
1262       assert(a6 != NULL);
1263       if XNN_UNPREDICTABLE(a6 != zero) {
1264         a6 = (const float*) ((uintptr_t) a6 + a_offset);
1265       }
1266       a += 7;
1267 
1268       size_t k = kc;
1269       do {
1270         const __m512 vb0123456789ABCDEF = _mm512_load_ps(w);
1271         w += 16;
1272 
1273         const __m512 va0 = _mm512_set1_ps(*a0);
1274         vacc0x0123456789ABCDEF = _mm512_fmadd_ps(va0, vb0123456789ABCDEF, vacc0x0123456789ABCDEF);
1275         const __m512 va1 = _mm512_set1_ps(*a1);
1276         vacc1x0123456789ABCDEF = _mm512_fmadd_ps(va1, vb0123456789ABCDEF, vacc1x0123456789ABCDEF);
1277         const __m512 va2 = _mm512_set1_ps(*a2);
1278         vacc2x0123456789ABCDEF = _mm512_fmadd_ps(va2, vb0123456789ABCDEF, vacc2x0123456789ABCDEF);
1279         const __m512 va3 = _mm512_set1_ps(*a3);
1280         vacc3x0123456789ABCDEF = _mm512_fmadd_ps(va3, vb0123456789ABCDEF, vacc3x0123456789ABCDEF);
1281         const __m512 va4 = _mm512_set1_ps(*a4);
1282         vacc4x0123456789ABCDEF = _mm512_fmadd_ps(va4, vb0123456789ABCDEF, vacc4x0123456789ABCDEF);
1283         const __m512 va5 = _mm512_set1_ps(*a5);
1284         vacc5x0123456789ABCDEF = _mm512_fmadd_ps(va5, vb0123456789ABCDEF, vacc5x0123456789ABCDEF);
1285         const __m512 va6 = _mm512_set1_ps(*a6);
1286         vacc6x0123456789ABCDEF = _mm512_fmadd_ps(va6, vb0123456789ABCDEF, vacc6x0123456789ABCDEF);
1287 
1288         a0 += 1;
1289         a1 += 1;
1290         a2 += 1;
1291         a3 += 1;
1292         a4 += 1;
1293         a5 += 1;
1294         a6 += 1;
1295 
1296         k -= sizeof(float);
1297       } while (k != 0);
1298       p -= 7 * sizeof(void*);
1299     } while (p != 0);
1300 
1301     const __m512 vmin = _mm512_set1_ps(params->scalar.min);
1302     vacc0x0123456789ABCDEF = _mm512_max_ps(vacc0x0123456789ABCDEF, vmin);
1303     vacc1x0123456789ABCDEF = _mm512_max_ps(vacc1x0123456789ABCDEF, vmin);
1304     vacc2x0123456789ABCDEF = _mm512_max_ps(vacc2x0123456789ABCDEF, vmin);
1305     vacc3x0123456789ABCDEF = _mm512_max_ps(vacc3x0123456789ABCDEF, vmin);
1306     vacc4x0123456789ABCDEF = _mm512_max_ps(vacc4x0123456789ABCDEF, vmin);
1307     vacc5x0123456789ABCDEF = _mm512_max_ps(vacc5x0123456789ABCDEF, vmin);
1308     vacc6x0123456789ABCDEF = _mm512_max_ps(vacc6x0123456789ABCDEF, vmin);
1309 
1310     const __m512 vmax = _mm512_set1_ps(params->scalar.max);
1311     vacc0x0123456789ABCDEF = _mm512_min_ps(vacc0x0123456789ABCDEF, vmax);
1312     vacc1x0123456789ABCDEF = _mm512_min_ps(vacc1x0123456789ABCDEF, vmax);
1313     vacc2x0123456789ABCDEF = _mm512_min_ps(vacc2x0123456789ABCDEF, vmax);
1314     vacc3x0123456789ABCDEF = _mm512_min_ps(vacc3x0123456789ABCDEF, vmax);
1315     vacc4x0123456789ABCDEF = _mm512_min_ps(vacc4x0123456789ABCDEF, vmax);
1316     vacc5x0123456789ABCDEF = _mm512_min_ps(vacc5x0123456789ABCDEF, vmax);
1317     vacc6x0123456789ABCDEF = _mm512_min_ps(vacc6x0123456789ABCDEF, vmax);
1318 
1319     if XNN_LIKELY(nc >= 16) {
1320       _mm512_storeu_ps(c6, vacc6x0123456789ABCDEF);
1321       c6 = (float*) ((uintptr_t) c6 + cn_stride);
1322       _mm512_storeu_ps(c5, vacc5x0123456789ABCDEF);
1323       c5 = (float*) ((uintptr_t) c5 + cn_stride);
1324       _mm512_storeu_ps(c4, vacc4x0123456789ABCDEF);
1325       c4 = (float*) ((uintptr_t) c4 + cn_stride);
1326       _mm512_storeu_ps(c3, vacc3x0123456789ABCDEF);
1327       c3 = (float*) ((uintptr_t) c3 + cn_stride);
1328       _mm512_storeu_ps(c2, vacc2x0123456789ABCDEF);
1329       c2 = (float*) ((uintptr_t) c2 + cn_stride);
1330       _mm512_storeu_ps(c1, vacc1x0123456789ABCDEF);
1331       c1 = (float*) ((uintptr_t) c1 + cn_stride);
1332       _mm512_storeu_ps(c0, vacc0x0123456789ABCDEF);
1333       c0 = (float*) ((uintptr_t) c0 + cn_stride);
1334 
1335       a = (const float**restrict) ((uintptr_t) a - ks);
1336       nc -= 16;
1337     } else {
1338       if (nc & 15) {
1339         // Prepare mask for valid 32-bit elements (depends on nc).
1340         const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << nc) - UINT32_C(1)));
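        // Example: nc == 3 gives vmask == 0b0000000000000111, so only the first
        // 3 floats of each output row are written.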
1341 
1342         _mm512_mask_storeu_ps(c6, vmask, vacc6x0123456789ABCDEF);
1343         _mm512_mask_storeu_ps(c5, vmask, vacc5x0123456789ABCDEF);
1344         _mm512_mask_storeu_ps(c4, vmask, vacc4x0123456789ABCDEF);
1345         _mm512_mask_storeu_ps(c3, vmask, vacc3x0123456789ABCDEF);
1346         _mm512_mask_storeu_ps(c2, vmask, vacc2x0123456789ABCDEF);
1347         _mm512_mask_storeu_ps(c1, vmask, vacc1x0123456789ABCDEF);
1348         _mm512_mask_storeu_ps(c0, vmask, vacc0x0123456789ABCDEF);
1349       }
1350 
1351       nc = 0;
1352     }
1353   } while (nc != 0);
1354 }
1355 
1356 void xnn_f32_prelu_ukernel__avx512f_2x16(

1357     size_t rows,
1358     size_t channels,
1359     const float*restrict input,
1360     size_t input_stride,
1361     const float*restrict weights,
1362     float*restrict output,
1363     size_t output_stride)
1364 {
1365   assert(rows != 0);
1366   assert(channels != 0);
1367   assert(channels % sizeof(float) == 0);
1368 
1369   const float* i0 = input;
1370   float* o0 = output;
1371   const float* i1 = (const float*) ((uintptr_t) i0 + input_stride);
1372   float* o1 = (float*) ((uintptr_t) o0 + output_stride);
1373 
1374   const size_t input_increment = input_stride * 2 - channels;
1375   const size_t output_increment = output_stride * 2 - channels;
1376 
1377   const __m512 vzero = _mm512_setzero_ps();
1378   do {
1379     if XNN_UNPREDICTABLE(rows < 2) {
1380       i1 = i0;
1381       o1 = o0;
1382     }
1383 
1384     const float* w = weights;
1385     size_t c = channels;
1386     for (; c >= 16 * sizeof(float); c -= 16 * sizeof(float)) {
1387       const __m512 vw0123456789ABCDEF = _mm512_load_ps(w);
1388       w += 16;
1389 
1390       const __m512 vi0x0123456789ABCDEF = _mm512_loadu_ps(i0);
1391       i0 += 16;
1392       const __m512 vi1x0123456789ABCDEF = _mm512_loadu_ps(i1);
1393       i1 += 16;
1394 
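      // Lanes where the input is negative are replaced with input * weight; the
      // merge-masked multiply leaves non-negative lanes equal to the input.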
1395       const __mmask16 vsign0x0123456789ABCDEF = _mm512_cmp_ps_mask(vi0x0123456789ABCDEF, vzero, _CMP_LT_OQ);
1396       const __m512 vacc0x0123456789ABCDEF = _mm512_mask_mul_ps(vi0x0123456789ABCDEF, vsign0x0123456789ABCDEF, vi0x0123456789ABCDEF, vw0123456789ABCDEF);
1397       const __mmask16 vsign1x0123456789ABCDEF = _mm512_cmp_ps_mask(vi1x0123456789ABCDEF, vzero, _CMP_LT_OQ);
1398       const __m512 vacc1x0123456789ABCDEF = _mm512_mask_mul_ps(vi1x0123456789ABCDEF, vsign1x0123456789ABCDEF, vi1x0123456789ABCDEF, vw0123456789ABCDEF);
1399 
1400       _mm512_storeu_ps(o0, vacc0x0123456789ABCDEF);
1401       o0 += 16;
1402       _mm512_storeu_ps(o1, vacc1x0123456789ABCDEF);
1403       o1 += 16;
1404     }
1405     if XNN_UNLIKELY(c != 0) {
1406       assert(c >= 1 * sizeof(float));
1407       assert(c <= 15 * sizeof(float));
1408       // Prepare mask for valid 32-bit elements (depends on c).
1409       const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << (c >> 2 /* log2(sizeof(float))*/)) - UINT32_C(1)));
1410 
1411       const __m512 vw = _mm512_maskz_loadu_ps(vmask, w);
1412 
1413       const __m512 vi0 = _mm512_maskz_loadu_ps(vmask, i0);
1414       i0 = (const float*) ((uintptr_t) i0 + c);
1415       const __m512 vi1 = _mm512_maskz_loadu_ps(vmask, i1);
1416       i1 = (const float*) ((uintptr_t) i1 + c);
1417 
1418       const __mmask16 vsign0 = _mm512_cmp_ps_mask(vi0, vzero, _CMP_LT_OQ);
1419       const __m512 vacc0 = _mm512_mask_mul_ps(vi0, vsign0, vi0, vw);
1420       const __mmask16 vsign1 = _mm512_cmp_ps_mask(vi1, vzero, _CMP_LT_OQ);
1421       const __m512 vacc1 = _mm512_mask_mul_ps(vi1, vsign1, vi1, vw);
1422 
1423       _mm512_mask_storeu_ps(o0, vmask, vacc0);
1424       o0 = (float*) ((uintptr_t) o0 + c);
1425       _mm512_mask_storeu_ps(o1, vmask, vacc1);
1426       o1 = (float*) ((uintptr_t) o1 + c);
1427     }
1428     i0 = (const float*) ((uintptr_t) i0 + input_increment);
1429     o0 = (float*) ((uintptr_t) o0 + output_increment);
1430     i1 = (const float*) ((uintptr_t) i1 + input_increment);
1431     o1 = (float*) ((uintptr_t) o1 + output_increment);
1432     rows = doz(rows, 2);
1433   } while (rows != 0);
1434 }
1435 
1436 void xnn_f32_vadd_minmax_ukernel__avx512f_x32(
1437     size_t n,
1438     const float* a,
1439     const float* b,
1440     float* y,
1441     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
1442 {
1443   assert(n != 0);
1444   assert(n % sizeof(float) == 0);
1445   assert(a != NULL);
1446   assert(b != NULL);
1447   assert(y != NULL);
1448 
1449   const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
1450   const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
1451 
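  // n is a byte count. Process 32 floats per iteration, then 16, then handle the
  // 1-15 float tail with a masked load/store; results are clamped to [min, max].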
1452   for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
1453     const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
1454     const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
1455     a += 32;
1456 
1457     const __m512 vb0123456789ABCDEF = _mm512_loadu_ps(b);
1458     const __m512 vbGHIJKLMNOPQRSTUV = _mm512_loadu_ps(b + 16);
1459     b += 32;
1460 
1461     __m512 vy0123456789ABCDEF = _mm512_add_ps(va0123456789ABCDEF, vb0123456789ABCDEF);
1462     __m512 vyGHIJKLMNOPQRSTUV = _mm512_add_ps(vaGHIJKLMNOPQRSTUV, vbGHIJKLMNOPQRSTUV);
1463 
1464 
1465     vy0123456789ABCDEF = _mm512_max_ps(vy0123456789ABCDEF, vy_min);
1466     vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vyGHIJKLMNOPQRSTUV, vy_min);
1467 
1468     vy0123456789ABCDEF = _mm512_min_ps(vy0123456789ABCDEF, vy_max);
1469     vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vyGHIJKLMNOPQRSTUV, vy_max);
1470 
1471     _mm512_storeu_ps(y, vy0123456789ABCDEF);
1472     _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
1473     y += 32;
1474   }
1475   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
1476     const __m512 va = _mm512_loadu_ps(a);
1477     a += 16;
1478 
1479     const __m512 vb = _mm512_loadu_ps(b);
1480     b += 16;
1481 
1482     __m512 vy = _mm512_add_ps(va, vb);
1483     vy = _mm512_max_ps(vy, vy_min);
1484     vy = _mm512_min_ps(vy, vy_max);
1485     _mm512_storeu_ps(y, vy);
1486     y += 16;
1487   }
1488   if XNN_UNLIKELY(n != 0) {
1489     assert(n >= 1 * sizeof(float));
1490     assert(n <= 15 * sizeof(float));
1491     // Prepare mask for valid 32-bit elements (depends on n).
1492     n >>= 2 /* log2(sizeof(float)) */;
1493     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
1494 
1495     const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
1496     const __m512 vb = _mm512_maskz_loadu_ps(vmask, b);
1497 
1498     __m512 vy = _mm512_add_ps(va, vb);
1499     vy = _mm512_max_ps(vy, vy_min);
1500     vy = _mm512_min_ps(vy, vy_max);
1501     _mm512_mask_storeu_ps(y, vmask, vy);
1502   }
1503 }
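// Illustrative sketch of invoking the kernel above directly (hypothetical values;
// in practice it is reached through XNNPACK's operator API, which also fills in the
// params union via its init helpers rather than by direct field assignment):
//
//   float a[40], b[40], y[40];
//   /* ... fill a and b ... */
//   union xnn_f32_minmax_params params;
//   params.scalar.min = 0.0f;   // assumes the scalar params layout read above
//   params.scalar.max = 6.0f;
//   xnn_f32_vadd_minmax_ukernel__avx512f_x32(40 * sizeof(float), a, b, y, &params);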
1504 
1505 void xnn_f32_vaddc_minmax_ukernel__avx512f_x32(
1506     size_t n,
1507     const float* a,
1508     const float* b,
1509     float* y,
1510     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
1511 {
1512   assert(n != 0);
1513   assert(n % sizeof(float) == 0);
1514   assert(a != NULL);
1515   assert(b != NULL);
1516   assert(y != NULL);
1517 
1518   const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
1519   const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
1520 
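  // The "c" (scalar) variant broadcasts the single value *b to all 16 lanes once,
  // outside the loops; otherwise the structure matches the element-wise kernel above.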
1521   const __m512 vb = _mm512_set1_ps(*b);
1522   for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
1523     const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
1524     const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
1525     a += 32;
1526 
1527     __m512 vy0123456789ABCDEF = _mm512_add_ps(va0123456789ABCDEF, vb);
1528     __m512 vyGHIJKLMNOPQRSTUV = _mm512_add_ps(vaGHIJKLMNOPQRSTUV, vb);
1529 
1530 
1531     vy0123456789ABCDEF = _mm512_max_ps(vy0123456789ABCDEF, vy_min);
1532     vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vyGHIJKLMNOPQRSTUV, vy_min);
1533 
1534     vy0123456789ABCDEF = _mm512_min_ps(vy0123456789ABCDEF, vy_max);
1535     vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vyGHIJKLMNOPQRSTUV, vy_max);
1536 
1537     _mm512_storeu_ps(y, vy0123456789ABCDEF);
1538     _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
1539     y += 32;
1540   }
1541   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
1542     const __m512 va = _mm512_loadu_ps(a);
1543     a += 16;
1544 
1545     __m512 vy = _mm512_add_ps(va, vb);
1546     vy = _mm512_max_ps(vy, vy_min);
1547     vy = _mm512_min_ps(vy, vy_max);
1548     _mm512_storeu_ps(y, vy);
1549     y += 16;
1550   }
1551   if XNN_UNLIKELY(n != 0) {
1552     assert(n >= 1 * sizeof(float));
1553     assert(n <= 15 * sizeof(float));
1554     // Prepare mask for valid 32-bit elements (depends on n).
1555     n >>= 2 /* log2(sizeof(float)) */;
1556     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
1557 
1558     const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
1559 
1560     __m512 vy = _mm512_add_ps(va, vb);
1561     vy = _mm512_max_ps(vy, vy_min);
1562     vy = _mm512_min_ps(vy, vy_max);
1563     _mm512_mask_storeu_ps(y, vmask, vy);
1564   }
1565 }
1566 
1567 void xnn_f32_vdiv_minmax_ukernel__avx512f_x32(
1568     size_t n,
1569     const float* a,
1570     const float* b,
1571     float* y,
1572     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
1573 {
1574   assert(n != 0);
1575   assert(n % sizeof(float) == 0);
1576   assert(a != NULL);
1577   assert(b != NULL);
1578   assert(y != NULL);
1579 
1580   const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
1581   const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
1582 
1583   for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
1584     const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
1585     const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
1586     a += 32;
1587 
1588     const __m512 vb0123456789ABCDEF = _mm512_loadu_ps(b);
1589     const __m512 vbGHIJKLMNOPQRSTUV = _mm512_loadu_ps(b + 16);
1590     b += 32;
1591 
1592     __m512 vy0123456789ABCDEF = _mm512_div_ps(va0123456789ABCDEF, vb0123456789ABCDEF);
1593     __m512 vyGHIJKLMNOPQRSTUV = _mm512_div_ps(vaGHIJKLMNOPQRSTUV, vbGHIJKLMNOPQRSTUV);
1594 
1595 
1596     vy0123456789ABCDEF = _mm512_max_ps(vy0123456789ABCDEF, vy_min);
1597     vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vyGHIJKLMNOPQRSTUV, vy_min);
1598 
1599     vy0123456789ABCDEF = _mm512_min_ps(vy0123456789ABCDEF, vy_max);
1600     vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vyGHIJKLMNOPQRSTUV, vy_max);
1601 
1602     _mm512_storeu_ps(y, vy0123456789ABCDEF);
1603     _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
1604     y += 32;
1605   }
1606   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
1607     const __m512 va = _mm512_loadu_ps(a);
1608     a += 16;
1609 
1610     const __m512 vb = _mm512_loadu_ps(b);
1611     b += 16;
1612 
1613     __m512 vy = _mm512_div_ps(va, vb);
1614     vy = _mm512_max_ps(vy, vy_min);
1615     vy = _mm512_min_ps(vy, vy_max);
1616     _mm512_storeu_ps(y, vy);
1617     y += 16;
1618   }
1619   if XNN_UNLIKELY(n != 0) {
1620     assert(n >= 1 * sizeof(float));
1621     assert(n <= 15 * sizeof(float));
1622     // Prepare mask for valid 32-bit elements (depends on n).
1623     n >>= 2 /* log2(sizeof(float)) */;
1624     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
1625 
1626     const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
1627     const __m512 vb = _mm512_maskz_loadu_ps(vmask, b);
1628 
1629     __m512 vy = _mm512_div_ps(va, vb);
1630     vy = _mm512_max_ps(vy, vy_min);
1631     vy = _mm512_min_ps(vy, vy_max);
1632     _mm512_mask_storeu_ps(y, vmask, vy);
1633   }
1634 }
1635 
1636 void xnn_f32_vdivc_minmax_ukernel__avx512f_x32(
1637     size_t n,
1638     const float* a,
1639     const float* b,
1640     float* y,
1641     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
1642 {
1643   assert(n != 0);
1644   assert(n % sizeof(float) == 0);
1645   assert(a != NULL);
1646   assert(b != NULL);
1647   assert(y != NULL);
1648 
1649   const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
1650   const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
1651 
1652   const __m512 vb = _mm512_set1_ps(*b);
1653   for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
1654     const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
1655     const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
1656     a += 32;
1657 
1658     __m512 vy0123456789ABCDEF = _mm512_div_ps(va0123456789ABCDEF, vb);
1659     __m512 vyGHIJKLMNOPQRSTUV = _mm512_div_ps(vaGHIJKLMNOPQRSTUV, vb);
1660 
1661 
1662     vy0123456789ABCDEF = _mm512_max_ps(vy0123456789ABCDEF, vy_min);
1663     vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vyGHIJKLMNOPQRSTUV, vy_min);
1664 
1665     vy0123456789ABCDEF = _mm512_min_ps(vy0123456789ABCDEF, vy_max);
1666     vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vyGHIJKLMNOPQRSTUV, vy_max);
1667 
1668     _mm512_storeu_ps(y, vy0123456789ABCDEF);
1669     _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
1670     y += 32;
1671   }
1672   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
1673     const __m512 va = _mm512_loadu_ps(a);
1674     a += 16;
1675 
1676     __m512 vy = _mm512_div_ps(va, vb);
1677     vy = _mm512_max_ps(vy, vy_min);
1678     vy = _mm512_min_ps(vy, vy_max);
1679     _mm512_storeu_ps(y, vy);
1680     y += 16;
1681   }
1682   if XNN_UNLIKELY(n != 0) {
1683     assert(n >= 1 * sizeof(float));
1684     assert(n <= 15 * sizeof(float));
1685     // Prepare mask for valid 32-bit elements (depends on n).
1686     n >>= 2 /* log2(sizeof(float)) */;
1687     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
1688 
1689     const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
1690 
1691     __m512 vy = _mm512_div_ps(va, vb);
1692     vy = _mm512_max_ps(vy, vy_min);
1693     vy = _mm512_min_ps(vy, vy_max);
1694     _mm512_mask_storeu_ps(y, vmask, vy);
1695   }
1696 }
1697 
1698 void xnn_f32_vmax_ukernel__avx512f_x32(
1699     size_t n,
1700     const float* a,
1701     const float* b,
1702     float* y,
1703     const union xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
1704 {
1705   assert(n != 0);
1706   assert(n % sizeof(float) == 0);
1707   assert(a != NULL);
1708   assert(b != NULL);
1709   assert(y != NULL);
1710 
1711 
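  // Element-wise maximum. `params` is not read by these default (non-clamping)
  // variants, so no min/max constants are loaded; the same holds for vmaxc/vmin/vminc below.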
1712   for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
1713     const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
1714     const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
1715     a += 32;
1716 
1717     const __m512 vb0123456789ABCDEF = _mm512_loadu_ps(b);
1718     const __m512 vbGHIJKLMNOPQRSTUV = _mm512_loadu_ps(b + 16);
1719     b += 32;
1720 
1721     __m512 vy0123456789ABCDEF = _mm512_max_ps(va0123456789ABCDEF, vb0123456789ABCDEF);
1722     __m512 vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vaGHIJKLMNOPQRSTUV, vbGHIJKLMNOPQRSTUV);
1723 
1724 
1725 
1726     _mm512_storeu_ps(y, vy0123456789ABCDEF);
1727     _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
1728     y += 32;
1729   }
1730   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
1731     const __m512 va = _mm512_loadu_ps(a);
1732     a += 16;
1733 
1734     const __m512 vb = _mm512_loadu_ps(b);
1735     b += 16;
1736 
1737     __m512 vy = _mm512_max_ps(va, vb);
1738     _mm512_storeu_ps(y, vy);
1739     y += 16;
1740   }
1741   if XNN_UNLIKELY(n != 0) {
1742     assert(n >= 1 * sizeof(float));
1743     assert(n <= 15 * sizeof(float));
1744     // Prepare mask for valid 32-bit elements (depends on n).
1745     n >>= 2 /* log2(sizeof(float)) */;
1746     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
1747 
1748     const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
1749     const __m512 vb = _mm512_maskz_loadu_ps(vmask, b);
1750 
1751     __m512 vy = _mm512_max_ps(va, vb);
1752     _mm512_mask_storeu_ps(y, vmask, vy);
1753   }
1754 }
1755 
1756 void xnn_f32_vmaxc_ukernel__avx512f_x32(
1757     size_t n,
1758     const float* a,
1759     const float* b,
1760     float* y,
1761     const union xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
1762 {
1763   assert(n != 0);
1764   assert(n % sizeof(float) == 0);
1765   assert(a != NULL);
1766   assert(b != NULL);
1767   assert(y != NULL);
1768 
1769 
1770   const __m512 vb = _mm512_set1_ps(*b);
1771   for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
1772     const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
1773     const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
1774     a += 32;
1775 
1776     __m512 vy0123456789ABCDEF = _mm512_max_ps(va0123456789ABCDEF, vb);
1777     __m512 vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vaGHIJKLMNOPQRSTUV, vb);
1778 
1779 
1780 
1781     _mm512_storeu_ps(y, vy0123456789ABCDEF);
1782     _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
1783     y += 32;
1784   }
1785   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
1786     const __m512 va = _mm512_loadu_ps(a);
1787     a += 16;
1788 
1789     __m512 vy = _mm512_max_ps(va, vb);
1790     _mm512_storeu_ps(y, vy);
1791     y += 16;
1792   }
1793   if XNN_UNLIKELY(n != 0) {
1794     assert(n >= 1 * sizeof(float));
1795     assert(n <= 15 * sizeof(float));
1796     // Prepare mask for valid 32-bit elements (depends on n).
1797     n >>= 2 /* log2(sizeof(float)) */;
1798     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
1799 
1800     const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
1801 
1802     __m512 vy = _mm512_max_ps(va, vb);
1803     _mm512_mask_storeu_ps(y, vmask, vy);
1804   }
1805 }
1806 
1807 void xnn_f32_vmin_ukernel__avx512f_x32(
1808     size_t n,
1809     const float* a,
1810     const float* b,
1811     float* y,
1812     const union xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
1813 {
1814   assert(n != 0);
1815   assert(n % sizeof(float) == 0);
1816   assert(a != NULL);
1817   assert(b != NULL);
1818   assert(y != NULL);
1819 
1820 
1821   for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
1822     const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
1823     const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
1824     a += 32;
1825 
1826     const __m512 vb0123456789ABCDEF = _mm512_loadu_ps(b);
1827     const __m512 vbGHIJKLMNOPQRSTUV = _mm512_loadu_ps(b + 16);
1828     b += 32;
1829 
1830     __m512 vy0123456789ABCDEF = _mm512_min_ps(va0123456789ABCDEF, vb0123456789ABCDEF);
1831     __m512 vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vaGHIJKLMNOPQRSTUV, vbGHIJKLMNOPQRSTUV);
1832 
1833 
1834 
1835     _mm512_storeu_ps(y, vy0123456789ABCDEF);
1836     _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
1837     y += 32;
1838   }
1839   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
1840     const __m512 va = _mm512_loadu_ps(a);
1841     a += 16;
1842 
1843     const __m512 vb = _mm512_loadu_ps(b);
1844     b += 16;
1845 
1846     __m512 vy = _mm512_min_ps(va, vb);
1847     _mm512_storeu_ps(y, vy);
1848     y += 16;
1849   }
1850   if XNN_UNLIKELY(n != 0) {
1851     assert(n >= 1 * sizeof(float));
1852     assert(n <= 15 * sizeof(float));
1853     // Prepare mask for valid 32-bit elements (depends on n).
1854     n >>= 2 /* log2(sizeof(float)) */;
1855     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
1856 
1857     const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
1858     const __m512 vb = _mm512_maskz_loadu_ps(vmask, b);
1859 
1860     __m512 vy = _mm512_min_ps(va, vb);
1861     _mm512_mask_storeu_ps(y, vmask, vy);
1862   }
1863 }
1864 
1865 void xnn_f32_vminc_ukernel__avx512f_x32(
1866     size_t n,
1867     const float* a,
1868     const float* b,
1869     float* y,
1870     const union xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
1871 {
1872   assert(n != 0);
1873   assert(n % sizeof(float) == 0);
1874   assert(a != NULL);
1875   assert(b != NULL);
1876   assert(y != NULL);
1877 
1878 
1879   const __m512 vb = _mm512_set1_ps(*b);
1880   for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
1881     const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
1882     const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
1883     a += 32;
1884 
1885     __m512 vy0123456789ABCDEF = _mm512_min_ps(va0123456789ABCDEF, vb);
1886     __m512 vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vaGHIJKLMNOPQRSTUV, vb);
1887 
1888 
1889 
1890     _mm512_storeu_ps(y, vy0123456789ABCDEF);
1891     _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
1892     y += 32;
1893   }
1894   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
1895     const __m512 va = _mm512_loadu_ps(a);
1896     a += 16;
1897 
1898     __m512 vy = _mm512_min_ps(va, vb);
1899     _mm512_storeu_ps(y, vy);
1900     y += 16;
1901   }
1902   if XNN_UNLIKELY(n != 0) {
1903     assert(n >= 1 * sizeof(float));
1904     assert(n <= 15 * sizeof(float));
1905     // Prepare mask for valid 32-bit elements (depends on n).
1906     n >>= 2 /* log2(sizeof(float)) */;
1907     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
1908 
1909     const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
1910 
1911     __m512 vy = _mm512_min_ps(va, vb);
1912     _mm512_mask_storeu_ps(y, vmask, vy);
1913   }
1914 }
1915 
1916 void xnn_f32_vmul_minmax_ukernel__avx512f_x32(
1917     size_t n,
1918     const float* a,
1919     const float* b,
1920     float* y,
1921     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
1922 {
1923   assert(n != 0);
1924   assert(n % sizeof(float) == 0);
1925   assert(a != NULL);
1926   assert(b != NULL);
1927   assert(y != NULL);
1928 
1929   const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
1930   const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
1931 
1932   for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
1933     const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
1934     const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
1935     a += 32;
1936 
1937     const __m512 vb0123456789ABCDEF = _mm512_loadu_ps(b);
1938     const __m512 vbGHIJKLMNOPQRSTUV = _mm512_loadu_ps(b + 16);
1939     b += 32;
1940 
1941     __m512 vy0123456789ABCDEF = _mm512_mul_ps(va0123456789ABCDEF, vb0123456789ABCDEF);
1942     __m512 vyGHIJKLMNOPQRSTUV = _mm512_mul_ps(vaGHIJKLMNOPQRSTUV, vbGHIJKLMNOPQRSTUV);
1943 
1944 
1945     vy0123456789ABCDEF = _mm512_max_ps(vy0123456789ABCDEF, vy_min);
1946     vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vyGHIJKLMNOPQRSTUV, vy_min);
1947 
1948     vy0123456789ABCDEF = _mm512_min_ps(vy0123456789ABCDEF, vy_max);
1949     vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vyGHIJKLMNOPQRSTUV, vy_max);
1950 
1951     _mm512_storeu_ps(y, vy0123456789ABCDEF);
1952     _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
1953     y += 32;
1954   }
1955   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
1956     const __m512 va = _mm512_loadu_ps(a);
1957     a += 16;
1958 
1959     const __m512 vb = _mm512_loadu_ps(b);
1960     b += 16;
1961 
1962     __m512 vy = _mm512_mul_ps(va, vb);
1963     vy = _mm512_max_ps(vy, vy_min);
1964     vy = _mm512_min_ps(vy, vy_max);
1965     _mm512_storeu_ps(y, vy);
1966     y += 16;
1967   }
1968   if XNN_UNLIKELY(n != 0) {
1969     assert(n >= 1 * sizeof(float));
1970     assert(n <= 15 * sizeof(float));
1971     // Prepare mask for valid 32-bit elements (depends on n).
1972     n >>= 2 /* log2(sizeof(float)) */;
1973     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
1974 
1975     const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
1976     const __m512 vb = _mm512_maskz_loadu_ps(vmask, b);
1977 
1978     __m512 vy = _mm512_mul_ps(va, vb);
1979     vy = _mm512_max_ps(vy, vy_min);
1980     vy = _mm512_min_ps(vy, vy_max);
1981     _mm512_mask_storeu_ps(y, vmask, vy);
1982   }
1983 }
1984 
1985 void xnn_f32_vmulc_minmax_ukernel__avx512f_x32(
1986     size_t n,
1987     const float* a,
1988     const float* b,
1989     float* y,
1990     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
1991 {
1992   assert(n != 0);
1993   assert(n % sizeof(float) == 0);
1994   assert(a != NULL);
1995   assert(b != NULL);
1996   assert(y != NULL);
1997 
1998   const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
1999   const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
2000 
2001   const __m512 vb = _mm512_set1_ps(*b);
2002   for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
2003     const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
2004     const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
2005     a += 32;
2006 
2007     __m512 vy0123456789ABCDEF = _mm512_mul_ps(va0123456789ABCDEF, vb);
2008     __m512 vyGHIJKLMNOPQRSTUV = _mm512_mul_ps(vaGHIJKLMNOPQRSTUV, vb);
2009 
2010 
2011     vy0123456789ABCDEF = _mm512_max_ps(vy0123456789ABCDEF, vy_min);
2012     vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vyGHIJKLMNOPQRSTUV, vy_min);
2013 
2014     vy0123456789ABCDEF = _mm512_min_ps(vy0123456789ABCDEF, vy_max);
2015     vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vyGHIJKLMNOPQRSTUV, vy_max);
2016 
2017     _mm512_storeu_ps(y, vy0123456789ABCDEF);
2018     _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
2019     y += 32;
2020   }
2021   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2022     const __m512 va = _mm512_loadu_ps(a);
2023     a += 16;
2024 
2025     __m512 vy = _mm512_mul_ps(va, vb);
2026     vy = _mm512_max_ps(vy, vy_min);
2027     vy = _mm512_min_ps(vy, vy_max);
2028     _mm512_storeu_ps(y, vy);
2029     y += 16;
2030   }
2031   if XNN_UNLIKELY(n != 0) {
2032     assert(n >= 1 * sizeof(float));
2033     assert(n <= 15 * sizeof(float));
2034     // Prepare mask for valid 32-bit elements (depends on n).
2035     n >>= 2 /* log2(sizeof(float)) */;
2036     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2037 
2038     const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
2039 
2040     __m512 vy = _mm512_mul_ps(va, vb);
2041     vy = _mm512_max_ps(vy, vy_min);
2042     vy = _mm512_min_ps(vy, vy_max);
2043     _mm512_mask_storeu_ps(y, vmask, vy);
2044   }
2045 }
2046 
2047 void xnn_f32_vrdivc_minmax_ukernel__avx512f_x32(
2048     size_t n,
2049     const float* a,
2050     const float* b,
2051     float* y,
2052     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
2053 {
2054   assert(n != 0);
2055   assert(n % sizeof(float) == 0);
2056   assert(a != NULL);
2057   assert(b != NULL);
2058   assert(y != NULL);
2059 
2060   const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
2061   const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
2062 
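  // Reversed-operand variant: computes y = b / a with the scalar *b broadcast,
  // i.e. the broadcast constant is the dividend rather than the divisor.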
2063   const __m512 vb = _mm512_set1_ps(*b);
2064   for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
2065     const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
2066     const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
2067     a += 32;
2068 
2069     __m512 vy0123456789ABCDEF = _mm512_div_ps(vb, va0123456789ABCDEF);
2070     __m512 vyGHIJKLMNOPQRSTUV = _mm512_div_ps(vb, vaGHIJKLMNOPQRSTUV);
2071 
2072 
2073     vy0123456789ABCDEF = _mm512_max_ps(vy0123456789ABCDEF, vy_min);
2074     vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vyGHIJKLMNOPQRSTUV, vy_min);
2075 
2076     vy0123456789ABCDEF = _mm512_min_ps(vy0123456789ABCDEF, vy_max);
2077     vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vyGHIJKLMNOPQRSTUV, vy_max);
2078 
2079     _mm512_storeu_ps(y, vy0123456789ABCDEF);
2080     _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
2081     y += 32;
2082   }
2083   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2084     const __m512 va = _mm512_loadu_ps(a);
2085     a += 16;
2086 
2087     __m512 vy = _mm512_div_ps(vb, va);
2088     vy = _mm512_max_ps(vy, vy_min);
2089     vy = _mm512_min_ps(vy, vy_max);
2090     _mm512_storeu_ps(y, vy);
2091     y += 16;
2092   }
2093   if XNN_UNLIKELY(n != 0) {
2094     assert(n >= 1 * sizeof(float));
2095     assert(n <= 15 * sizeof(float));
2096     // Prepare mask for valid 32-bit elements (depends on n).
2097     n >>= 2 /* log2(sizeof(float)) */;
2098     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2099 
2100     const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
2101 
2102     __m512 vy = _mm512_div_ps(vb, va);
2103     vy = _mm512_max_ps(vy, vy_min);
2104     vy = _mm512_min_ps(vy, vy_max);
2105     _mm512_mask_storeu_ps(y, vmask, vy);
2106   }
2107 }
2108 
2109 void xnn_f32_vrsubc_minmax_ukernel__avx512f_x32(
2110     size_t n,
2111     const float* a,
2112     const float* b,
2113     float* y,
2114     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
2115 {
2116   assert(n != 0);
2117   assert(n % sizeof(float) == 0);
2118   assert(a != NULL);
2119   assert(b != NULL);
2120   assert(y != NULL);
2121 
2122   const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
2123   const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
2124 
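  // Reversed-operand variant: computes y = b - a with the scalar *b broadcast.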
2125   const __m512 vb = _mm512_set1_ps(*b);
2126   for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
2127     const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
2128     const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
2129     a += 32;
2130 
2131     __m512 vy0123456789ABCDEF = _mm512_sub_ps(vb, va0123456789ABCDEF);
2132     __m512 vyGHIJKLMNOPQRSTUV = _mm512_sub_ps(vb, vaGHIJKLMNOPQRSTUV);
2133 
2134 
2135     vy0123456789ABCDEF = _mm512_max_ps(vy0123456789ABCDEF, vy_min);
2136     vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vyGHIJKLMNOPQRSTUV, vy_min);
2137 
2138     vy0123456789ABCDEF = _mm512_min_ps(vy0123456789ABCDEF, vy_max);
2139     vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vyGHIJKLMNOPQRSTUV, vy_max);
2140 
2141     _mm512_storeu_ps(y, vy0123456789ABCDEF);
2142     _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
2143     y += 32;
2144   }
2145   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2146     const __m512 va = _mm512_loadu_ps(a);
2147     a += 16;
2148 
2149     __m512 vy = _mm512_sub_ps(vb, va);
2150     vy = _mm512_max_ps(vy, vy_min);
2151     vy = _mm512_min_ps(vy, vy_max);
2152     _mm512_storeu_ps(y, vy);
2153     y += 16;
2154   }
2155   if XNN_UNLIKELY(n != 0) {
2156     assert(n >= 1 * sizeof(float));
2157     assert(n <= 15 * sizeof(float));
2158     // Prepare mask for valid 32-bit elements (depends on n).
2159     n >>= 2 /* log2(sizeof(float)) */;
2160     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2161 
2162     const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
2163 
2164     __m512 vy = _mm512_sub_ps(vb, va);
2165     vy = _mm512_max_ps(vy, vy_min);
2166     vy = _mm512_min_ps(vy, vy_max);
2167     _mm512_mask_storeu_ps(y, vmask, vy);
2168   }
2169 }
2170 
2171 void xnn_f32_vsqrdiff_ukernel__avx512f_x32(
2172     size_t n,
2173     const float* a,
2174     const float* b,
2175     float* y,
2176     const union xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
2177 {
2178   assert(n != 0);
2179   assert(n % sizeof(float) == 0);
2180   assert(a != NULL);
2181   assert(b != NULL);
2182   assert(y != NULL);
2183 
2184 
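  // Squared difference: y = (a - b)^2. No clamping is applied, so `params` is not read.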
2185   for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
2186     const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
2187     const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
2188     a += 32;
2189 
2190     const __m512 vb0123456789ABCDEF = _mm512_loadu_ps(b);
2191     const __m512 vbGHIJKLMNOPQRSTUV = _mm512_loadu_ps(b + 16);
2192     b += 32;
2193 
2194     __m512 vy0123456789ABCDEF = _mm512_sub_ps(va0123456789ABCDEF, vb0123456789ABCDEF);
2195     __m512 vyGHIJKLMNOPQRSTUV = _mm512_sub_ps(vaGHIJKLMNOPQRSTUV, vbGHIJKLMNOPQRSTUV);
2196 
2197     vy0123456789ABCDEF = _mm512_mul_ps(vy0123456789ABCDEF, vy0123456789ABCDEF);
2198     vyGHIJKLMNOPQRSTUV = _mm512_mul_ps(vyGHIJKLMNOPQRSTUV, vyGHIJKLMNOPQRSTUV);
2199 
2200 
2201     _mm512_storeu_ps(y, vy0123456789ABCDEF);
2202     _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
2203     y += 32;
2204   }
2205   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2206     const __m512 va = _mm512_loadu_ps(a);
2207     a += 16;
2208 
2209     const __m512 vb = _mm512_loadu_ps(b);
2210     b += 16;
2211 
2212     __m512 vy = _mm512_sub_ps(va, vb);
2213     vy = _mm512_mul_ps(vy, vy);
2214     _mm512_storeu_ps(y, vy);
2215     y += 16;
2216   }
2217   if XNN_UNLIKELY(n != 0) {
2218     assert(n >= 1 * sizeof(float));
2219     assert(n <= 15 * sizeof(float));
2220     // Prepare mask for valid 32-bit elements (depends on n).
2221     n >>= 2 /* log2(sizeof(float)) */;
2222     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2223 
2224     const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
2225     const __m512 vb = _mm512_maskz_loadu_ps(vmask, b);
2226 
2227     __m512 vy = _mm512_sub_ps(va, vb);
2228     vy = _mm512_mul_ps(vy, vy);
2229     _mm512_mask_storeu_ps(y, vmask, vy);
2230   }
2231 }
2232 
2233 void xnn_f32_vsqrdiffc_ukernel__avx512f_x32(
2234     size_t n,
2235     const float* a,
2236     const float* b,
2237     float* y,
2238     const union xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
2239 {
2240   assert(n != 0);
2241   assert(n % sizeof(float) == 0);
2242   assert(a != NULL);
2243   assert(b != NULL);
2244   assert(y != NULL);
2245 
2246 
2247   const __m512 vb = _mm512_set1_ps(*b);
2248   for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
2249     const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
2250     const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
2251     a += 32;
2252 
2253     __m512 vy0123456789ABCDEF = _mm512_sub_ps(va0123456789ABCDEF, vb);
2254     __m512 vyGHIJKLMNOPQRSTUV = _mm512_sub_ps(vaGHIJKLMNOPQRSTUV, vb);
2255 
2256     vy0123456789ABCDEF = _mm512_mul_ps(vy0123456789ABCDEF, vy0123456789ABCDEF);
2257     vyGHIJKLMNOPQRSTUV = _mm512_mul_ps(vyGHIJKLMNOPQRSTUV, vyGHIJKLMNOPQRSTUV);
2258 
2259 
2260     _mm512_storeu_ps(y, vy0123456789ABCDEF);
2261     _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
2262     y += 32;
2263   }
2264   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2265     const __m512 va = _mm512_loadu_ps(a);
2266     a += 16;
2267 
2268     __m512 vy = _mm512_sub_ps(va, vb);
2269     vy = _mm512_mul_ps(vy, vy);
2270     _mm512_storeu_ps(y, vy);
2271     y += 16;
2272   }
2273   if XNN_UNLIKELY(n != 0) {
2274     assert(n >= 1 * sizeof(float));
2275     assert(n <= 15 * sizeof(float));
2276     // Prepare mask for valid 32-bit elements (depends on n).
2277     n >>= 2 /* log2(sizeof(float)) */;
2278     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2279 
2280     const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
2281 
2282     __m512 vy = _mm512_sub_ps(va, vb);
2283     vy = _mm512_mul_ps(vy, vy);
2284     _mm512_mask_storeu_ps(y, vmask, vy);
2285   }
2286 }
2287 
2288 void xnn_f32_vsub_minmax_ukernel__avx512f_x32(
2289     size_t n,
2290     const float* a,
2291     const float* b,
2292     float* y,
2293     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
2294 {
2295   assert(n != 0);
2296   assert(n % sizeof(float) == 0);
2297   assert(a != NULL);
2298   assert(b != NULL);
2299   assert(y != NULL);
2300 
2301   const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
2302   const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
2303 
2304   for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
2305     const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
2306     const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
2307     a += 32;
2308 
2309     const __m512 vb0123456789ABCDEF = _mm512_loadu_ps(b);
2310     const __m512 vbGHIJKLMNOPQRSTUV = _mm512_loadu_ps(b + 16);
2311     b += 32;
2312 
2313     __m512 vy0123456789ABCDEF = _mm512_sub_ps(va0123456789ABCDEF, vb0123456789ABCDEF);
2314     __m512 vyGHIJKLMNOPQRSTUV = _mm512_sub_ps(vaGHIJKLMNOPQRSTUV, vbGHIJKLMNOPQRSTUV);
2315 
2316 
2317     vy0123456789ABCDEF = _mm512_max_ps(vy0123456789ABCDEF, vy_min);
2318     vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vyGHIJKLMNOPQRSTUV, vy_min);
2319 
2320     vy0123456789ABCDEF = _mm512_min_ps(vy0123456789ABCDEF, vy_max);
2321     vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vyGHIJKLMNOPQRSTUV, vy_max);
2322 
2323     _mm512_storeu_ps(y, vy0123456789ABCDEF);
2324     _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
2325     y += 32;
2326   }
2327   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2328     const __m512 va = _mm512_loadu_ps(a);
2329     a += 16;
2330 
2331     const __m512 vb = _mm512_loadu_ps(b);
2332     b += 16;
2333 
2334     __m512 vy = _mm512_sub_ps(va, vb);
2335     vy = _mm512_max_ps(vy, vy_min);
2336     vy = _mm512_min_ps(vy, vy_max);
2337     _mm512_storeu_ps(y, vy);
2338     y += 16;
2339   }
2340   if XNN_UNLIKELY(n != 0) {
2341     assert(n >= 1 * sizeof(float));
2342     assert(n <= 15 * sizeof(float));
2343     // Prepare mask for valid 32-bit elements (depends on n).
2344     n >>= 2 /* log2(sizeof(float)) */;
2345     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2346 
2347     const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
2348     const __m512 vb = _mm512_maskz_loadu_ps(vmask, b);
2349 
2350     __m512 vy = _mm512_sub_ps(va, vb);
2351     vy = _mm512_max_ps(vy, vy_min);
2352     vy = _mm512_min_ps(vy, vy_max);
2353     _mm512_mask_storeu_ps(y, vmask, vy);
2354   }
2355 }
2356 
2357 void xnn_f32_vsubc_minmax_ukernel__avx512f_x32(
2358     size_t n,
2359     const float* a,
2360     const float* b,
2361     float* y,
2362     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
2363 {
2364   assert(n != 0);
2365   assert(n % sizeof(float) == 0);
2366   assert(a != NULL);
2367   assert(b != NULL);
2368   assert(y != NULL);
2369 
2370   const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
2371   const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
2372 
2373   const __m512 vb = _mm512_set1_ps(*b);
2374   for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
2375     const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
2376     const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
2377     a += 32;
2378 
2379     __m512 vy0123456789ABCDEF = _mm512_sub_ps(va0123456789ABCDEF, vb);
2380     __m512 vyGHIJKLMNOPQRSTUV = _mm512_sub_ps(vaGHIJKLMNOPQRSTUV, vb);
2381 
2382 
2383     vy0123456789ABCDEF = _mm512_max_ps(vy0123456789ABCDEF, vy_min);
2384     vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vyGHIJKLMNOPQRSTUV, vy_min);
2385 
2386     vy0123456789ABCDEF = _mm512_min_ps(vy0123456789ABCDEF, vy_max);
2387     vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vyGHIJKLMNOPQRSTUV, vy_max);
2388 
2389     _mm512_storeu_ps(y, vy0123456789ABCDEF);
2390     _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
2391     y += 32;
2392   }
2393   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2394     const __m512 va = _mm512_loadu_ps(a);
2395     a += 16;
2396 
2397     __m512 vy = _mm512_sub_ps(va, vb);
2398     vy = _mm512_max_ps(vy, vy_min);
2399     vy = _mm512_min_ps(vy, vy_max);
2400     _mm512_storeu_ps(y, vy);
2401     y += 16;
2402   }
2403   if XNN_UNLIKELY(n != 0) {
2404     assert(n >= 1 * sizeof(float));
2405     assert(n <= 15 * sizeof(float));
2406     // Prepare mask for valid 32-bit elements (depends on n).
2407     n >>= 2 /* log2(sizeof(float)) */;
2408     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2409 
2410     const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
2411 
2412     __m512 vy = _mm512_sub_ps(va, vb);
2413     vy = _mm512_max_ps(vy, vy_min);
2414     vy = _mm512_min_ps(vy, vy_max);
2415     _mm512_mask_storeu_ps(y, vmask, vy);
2416   }
2417 }
2418 
2419 void xnn_f32_vclamp_ukernel__avx512f_x16(
2420     size_t n,
2421     const float* x,
2422     float* y,
2423     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
2424 {
2425   assert(n != 0);
2426   assert(n % sizeof(float) == 0);
2427   assert(x != NULL);
2428   assert(y != NULL);
2429 
2430   const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
2431   const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
2432 
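  // Clamp each element to [params->scalar.min, params->scalar.max].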
2433   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2434     __m512 vacc0123456789ABCDEF = _mm512_loadu_ps(x);
2435     x += 16;
2436 
2437     vacc0123456789ABCDEF = _mm512_max_ps(vacc0123456789ABCDEF, vy_min);
2438 
2439     vacc0123456789ABCDEF = _mm512_min_ps(vacc0123456789ABCDEF, vy_max);
2440 
2441     _mm512_storeu_ps(y, vacc0123456789ABCDEF);
2442     y += 16;
2443   }
2444   if XNN_UNLIKELY(n != 0) {
2445     assert(n >= 1 * sizeof(float));
2446     assert(n <= 15 * sizeof(float));
2447     // Prepare mask for valid 32-bit elements (depends on n).
2448     n >>= 2 /* log2(sizeof(float)) */;
2449     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2450 
2451     __m512 vacc = _mm512_maskz_loadu_ps(vmask, x);
2452     vacc = _mm512_max_ps(vacc, vy_min);
2453     vacc = _mm512_min_ps(vacc, vy_max);
2454     _mm512_mask_storeu_ps(y, vmask, vacc);
2455   }
2456 }
2457 
2458 void xnn_f32_velu_ukernel__avx512f_rr1_lut16_p3_perm_x64(
2459     size_t n,
2460     const float* x,
2461     float* y,
2462     const union xnn_f32_elu_params params[restrict XNN_MIN_ELEMENTS(1)])
2463 {
2464   assert(n != 0);
2465   assert(n % sizeof(float) == 0);
2466 
2467   const __m512 vprescale = _mm512_set1_ps(params->avx512_rr1_lut16_p3.prescale);
2468   const __m512 valpha = _mm512_set1_ps(params->avx512_rr1_lut16_p3.alpha);
2469   const __m512 vbeta = _mm512_set1_ps(params->avx512_rr1_lut16_p3.beta);
2470   const __m512 vsat_cutoff = _mm512_set1_ps(params->avx512_rr1_lut16_p3.sat_cutoff);
2471   const __m512 vmagic_bias = _mm512_set1_ps(params->avx512_rr1_lut16_p3.magic_bias);
2472   const __m512 vlog2e = _mm512_set1_ps(params->avx512_rr1_lut16_p3.log2e);
2473   const __m512 vminus_ln2 = _mm512_set1_ps(params->avx512_rr1_lut16_p3.minus_ln2);
2474   const __m512 vc3 = _mm512_set1_ps(params->avx512_rr1_lut16_p3.c3);
2475   const __m512 vc2 = _mm512_set1_ps(params->avx512_rr1_lut16_p3.c2);
2476   const __m512i vtable = _mm512_load_si512(params->avx512_rr1_lut16_p3.table);
2477 
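  // ELU: lanes with x >= 0 become x * beta; negative lanes compute
  // alpha * (exp(prescale * x) - 1), roughly as follows:
  //   z = max(sat_cutoff, x * prescale)        saturate very negative inputs
  //   n = z * log2e + magic_bias               exponent of 2^z in steps of 1/16
  //   s = 2^(n/16), reconstructed from a 16-entry table indexed by the low 4 bits
  //       of n (_mm512_permutexvar_epi32) plus the remaining integer bits shifted
  //       into the float exponent field (the slli by 19)
  //   t = n * (-ln2/16) + z                    reduction remainder
  //   e^z - 1 ~= s * (t + c2*t^2 + c3*t^3) + (s - 1), evaluated with FMAs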
2478   for (; n >= 64 * sizeof(float); n -= 64 * sizeof(float)) {
2479     __m512 vx0 = _mm512_loadu_ps(x);
2480     __m512 vx1 = _mm512_loadu_ps(x + 16);
2481     __m512 vx2 = _mm512_loadu_ps(x + 32);
2482     __m512 vx3 = _mm512_loadu_ps(x + 48);
2483     x += 64;
2484 
2485     const __m512 vz0 = _mm512_max_ps(vsat_cutoff, _mm512_mul_ps(vx0, vprescale));
2486     const __m512 vz1 = _mm512_max_ps(vsat_cutoff, _mm512_mul_ps(vx1, vprescale));
2487     const __m512 vz2 = _mm512_max_ps(vsat_cutoff, _mm512_mul_ps(vx2, vprescale));
2488     const __m512 vz3 = _mm512_max_ps(vsat_cutoff, _mm512_mul_ps(vx3, vprescale));
2489 
2490     __m512 vn0 = _mm512_fmadd_ps(vz0, vlog2e, vmagic_bias);
2491     __m512 vn1 = _mm512_fmadd_ps(vz1, vlog2e, vmagic_bias);
2492     __m512 vn2 = _mm512_fmadd_ps(vz2, vlog2e, vmagic_bias);
2493     __m512 vn3 = _mm512_fmadd_ps(vz3, vlog2e, vmagic_bias);
2494 
2495     const __m512i ven0 = _mm512_slli_epi32(_mm512_castps_si512(vn0), 19);
2496     const __m512i vl0 = _mm512_permutexvar_epi32(_mm512_castps_si512(vn0), vtable);
2497     const __m512i ven1 = _mm512_slli_epi32(_mm512_castps_si512(vn1), 19);
2498     const __m512i vl1 = _mm512_permutexvar_epi32(_mm512_castps_si512(vn1), vtable);
2499     const __m512i ven2 = _mm512_slli_epi32(_mm512_castps_si512(vn2), 19);
2500     const __m512i vl2 = _mm512_permutexvar_epi32(_mm512_castps_si512(vn2), vtable);
2501     const __m512i ven3 = _mm512_slli_epi32(_mm512_castps_si512(vn3), 19);
2502     const __m512i vl3 = _mm512_permutexvar_epi32(_mm512_castps_si512(vn3), vtable);
2503 
2504     __m512 vs0 = _mm512_castsi512_ps(_mm512_add_epi32(vl0, ven0));
2505     vn0 = _mm512_sub_ps(vn0, vmagic_bias);
2506     __m512 vs1 = _mm512_castsi512_ps(_mm512_add_epi32(vl1, ven1));
2507     vn1 = _mm512_sub_ps(vn1, vmagic_bias);
2508     __m512 vs2 = _mm512_castsi512_ps(_mm512_add_epi32(vl2, ven2));
2509     vn2 = _mm512_sub_ps(vn2, vmagic_bias);
2510     __m512 vs3 = _mm512_castsi512_ps(_mm512_add_epi32(vl3, ven3));
2511     vn3 = _mm512_sub_ps(vn3, vmagic_bias);
2512 
2513     __m512 vt0 = _mm512_fmadd_ps(vn0, vminus_ln2, vz0);
2514     __m512 vt1 = _mm512_fmadd_ps(vn1, vminus_ln2, vz1);
2515     __m512 vt2 = _mm512_fmadd_ps(vn2, vminus_ln2, vz2);
2516     __m512 vt3 = _mm512_fmadd_ps(vn3, vminus_ln2, vz3);
2517 
2518     __m512 vp0 = _mm512_fmadd_ps(vc3, vt0, vc2);
2519     __m512 vp1 = _mm512_fmadd_ps(vc3, vt1, vc2);
2520     __m512 vp2 = _mm512_fmadd_ps(vc3, vt2, vc2);
2521     __m512 vp3 = _mm512_fmadd_ps(vc3, vt3, vc2);
2522 
2523     vp0 = _mm512_mul_ps(vp0, vt0);
2524     vt0 = _mm512_mul_ps(vt0, vs0);
2525     vp1 = _mm512_mul_ps(vp1, vt1);
2526     vt1 = _mm512_mul_ps(vt1, vs1);
2527     vp2 = _mm512_mul_ps(vp2, vt2);
2528     vt2 = _mm512_mul_ps(vt2, vs2);
2529     vp3 = _mm512_mul_ps(vp3, vt3);
2530     vt3 = _mm512_mul_ps(vt3, vs3);
2531 
2532     vs0 = _mm512_fmsub_ps(vs0, valpha, valpha);
2533     vs1 = _mm512_fmsub_ps(vs1, valpha, valpha);
2534     vs2 = _mm512_fmsub_ps(vs2, valpha, valpha);
2535     vs3 = _mm512_fmsub_ps(vs3, valpha, valpha);
2536 
2537     vp0 = _mm512_fmadd_ps(vp0, vt0, vt0);
2538     vp1 = _mm512_fmadd_ps(vp1, vt1, vt1);
2539     vp2 = _mm512_fmadd_ps(vp2, vt2, vt2);
2540     vp3 = _mm512_fmadd_ps(vp3, vt3, vt3);
2541 
2542     const __m512 vzero = _mm512_setzero_ps();
2543     __m512 vy0 = _mm512_fmadd_ps(vp0, valpha, vs0);
2544     const __mmask16 vsign0 = _mm512_cmp_ps_mask(vx0, vzero, _CMP_NLT_US);
2545     __m512 vy1 = _mm512_fmadd_ps(vp1, valpha, vs1);
2546     const __mmask16 vsign1 = _mm512_cmp_ps_mask(vx1, vzero, _CMP_NLT_US);
2547     __m512 vy2 = _mm512_fmadd_ps(vp2, valpha, vs2);
2548     const __mmask16 vsign2 = _mm512_cmp_ps_mask(vx2, vzero, _CMP_NLT_US);
2549     __m512 vy3 = _mm512_fmadd_ps(vp3, valpha, vs3);
2550     const __mmask16 vsign3 = _mm512_cmp_ps_mask(vx3, vzero, _CMP_NLT_US);
2551 
2552     vy0 = _mm512_mask_mul_ps(vy0, vsign0, vx0, vbeta);
2553     vy1 = _mm512_mask_mul_ps(vy1, vsign1, vx1, vbeta);
2554     vy2 = _mm512_mask_mul_ps(vy2, vsign2, vx2, vbeta);
2555     vy3 = _mm512_mask_mul_ps(vy3, vsign3, vx3, vbeta);
2556 
2557     _mm512_storeu_ps(y, vy0);
2558     _mm512_storeu_ps(y + 16, vy1);
2559     _mm512_storeu_ps(y + 32, vy2);
2560     _mm512_storeu_ps(y + 48, vy3);
2561     y += 64;
2562   }
2563   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2564     __m512 vx = _mm512_loadu_ps(x);
2565     x += 16;
2566 
2567     const __m512 vz = _mm512_max_ps(vsat_cutoff, _mm512_mul_ps(vx, vprescale));
2568     const __mmask16 vsign = _mm512_cmp_ps_mask(vx, _mm512_setzero_ps(), _CMP_NLT_US);
2569 
2570     __m512 vn = _mm512_fmadd_ps(vz, vlog2e, vmagic_bias);
2571     const __m512i ven = _mm512_slli_epi32(_mm512_castps_si512(vn), 19);
2572     const __m512i vl = _mm512_permutexvar_epi32(_mm512_castps_si512(vn), vtable);
2573     __m512 vs = _mm512_castsi512_ps(_mm512_add_epi32(vl, ven));
2574     vn = _mm512_sub_ps(vn, vmagic_bias);
2575 
2576     __m512 vt = _mm512_fmadd_ps(vn, vminus_ln2, vz);
2577 
2578     __m512 vp = _mm512_fmadd_ps(vc3, vt, vc2);
2579     vp = _mm512_mul_ps(vp, vt);
2580 
2581     vt = _mm512_mul_ps(vt, vs);
2582     vs = _mm512_fmsub_ps(vs, valpha, valpha);
2583     vp = _mm512_fmadd_ps(vp, vt, vt);
2584     __m512 vy = _mm512_fmadd_ps(vp, valpha, vs);
2585 
2586     vy = _mm512_mask_mul_ps(vy, vsign, vx, vbeta);
2587 
2588     _mm512_storeu_ps(y, vy);
2589     y += 16;
2590   }
2591   if XNN_UNLIKELY(n != 0) {
2592     assert(n >= 1 * sizeof(float));
2593     assert(n <= 15 * sizeof(float));
2594     // Prepare mask for valid 32-bit elements (depends on n).
2595     n >>= 2 /* log2(sizeof(float)) */;
2596     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2597 
2598     __m512 vx = _mm512_maskz_loadu_ps(vmask, x);
2599 
2600     const __m512 vz = _mm512_max_ps(vsat_cutoff, _mm512_mul_ps(vx, vprescale));
2601     const __mmask16 vsign = _mm512_cmp_ps_mask(vx, _mm512_setzero_ps(), _CMP_NLT_US);
2602 
2603     __m512 vn = _mm512_fmadd_ps(vz, vlog2e, vmagic_bias);
2604     const __m512i ven = _mm512_slli_epi32(_mm512_castps_si512(vn), 19);
2605     const __m512i vl = _mm512_permutexvar_epi32(_mm512_castps_si512(vn), vtable);
2606     __m512 vs = _mm512_castsi512_ps(_mm512_add_epi32(vl, ven));
2607     vn = _mm512_sub_ps(vn, vmagic_bias);
2608 
2609     __m512 vt = _mm512_fmadd_ps(vn, vminus_ln2, vz);
2610 
2611     __m512 vp = _mm512_fmadd_ps(vc3, vt, vc2);
2612     vp = _mm512_mul_ps(vp, vt);
2613 
2614     vt = _mm512_mul_ps(vt, vs);
2615     vs = _mm512_fmsub_ps(vs, valpha, valpha);
2616     vp = _mm512_fmadd_ps(vp, vt, vt);
2617     __m512 vy = _mm512_fmadd_ps(vp, valpha, vs);
2618 
2619     vy = _mm512_mask_mul_ps(vy, vsign, vx, vbeta);
2620 
2621     _mm512_mask_storeu_ps(y, vmask, vy);
2622   }
2623 }
2624 
2625 void xnn_f32_vhswish_ukernel__avx512f_x16(
2626     size_t n,
2627     const float* x,
2628     float* y,
2629     const union xnn_f32_hswish_params params[restrict XNN_MIN_ELEMENTS(1)])
2630 {
2631   assert(n != 0);
2632   assert(n % sizeof(float) == 0);
2633 
2634   const __m512 vsixth = _mm512_set1_ps(params->avx512.sixth);
2635   const __m512 vhalf = _mm512_set1_ps(params->avx512.half);
2636   const __m512 vone = _mm512_set1_ps(params->avx512.one);
2637   const __m512 vzero = _mm512_setzero_ps();
2638 
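  // HardSwish: y = x * min(max(x * sixth + half, 0), one), i.e. x * clamp(x/6 + 1/2, 0, 1),
  // computed with a single FMA followed by the clamp.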
2639   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2640     const __m512 vx = _mm512_loadu_ps(x);
2641     x += 16;
2642     __m512 vacc = _mm512_fmadd_ps(vx, vsixth, vhalf);
2643     vacc = _mm512_max_ps(vacc, vzero);
2644     vacc = _mm512_min_ps(vacc, vone);
2645     vacc = _mm512_mul_ps(vacc, vx);
2646     _mm512_storeu_ps(y, vacc);
2647     y += 16;
2648   }
2649   if XNN_UNLIKELY(n != 0) {
2650     assert(n >= 1 * sizeof(float));
2651     assert(n <= 15 * sizeof(float));
2652     // Prepare mask for valid 32-bit elements (depends on n).
2653     n >>= 2 /* log2(sizeof(float)) */;
2654     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2655 
2656     const __m512 vx = _mm512_maskz_loadu_ps(vmask, x);
2657     __m512 vacc = _mm512_fmadd_ps(vx, vsixth, vhalf);
2658     vacc = _mm512_max_ps(vacc, vzero);
2659     vacc = _mm512_min_ps(vacc, vone);
2660     vacc = _mm512_mul_ps(vacc, vx);
2661     _mm512_mask_storeu_ps(y, vmask, vacc);
2662   }
2663 }
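
// Editorial note: scalar reference for the HardSwish computed by the kernel
// above.  This is an illustrative sketch added to this listing, not part of
// XNNPACK; it assumes the usual f32 hswish parameters sixth = 1/6, half = 0.5
// and one = 1.0, and the xnn_editorial_ name is hypothetical.
static inline float xnn_editorial_hswish_ref(float x) {
  float acc = x * (1.0f / 6.0f) + 0.5f;  // vacc = vx * vsixth + vhalf
  acc = acc < 0.0f ? 0.0f : acc;         // vacc = max(vacc, vzero)
  acc = acc > 1.0f ? 1.0f : acc;         // vacc = min(vacc, vone)
  return acc * x;                        // vacc = vacc * vx
}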
2664 
2665 void xnn_f32_vlrelu_ukernel__avx512f_x16(
2666     size_t n,
2667     const float* x,
2668     float* y,
2669     const union xnn_f32_lrelu_params params[restrict XNN_MIN_ELEMENTS(1)])
2670 {
2671   assert(n != 0);
2672   assert(n % sizeof(float) == 0);
2673 
2674   const __m512 vslope = _mm512_set1_ps(params->scalar.slope);
2675   const __m512 vzero = _mm512_setzero_ps();
2676 
2677   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2678     __m512 vacc0123456789ABCDEF = _mm512_loadu_ps(x);
2679     x += 16;
2680 
2681     const __mmask16 vsign0123456789ABCDEF = _mm512_cmp_ps_mask(vacc0123456789ABCDEF, vzero, _CMP_LT_OQ);
2682 
2683     vacc0123456789ABCDEF = _mm512_mask_mul_ps(vacc0123456789ABCDEF, vsign0123456789ABCDEF, vacc0123456789ABCDEF, vslope);
2684 
2685     _mm512_storeu_ps(y, vacc0123456789ABCDEF);
2686     y += 16;
2687   }
2688   if XNN_UNLIKELY(n != 0) {
2689     assert(n >= 1 * sizeof(float));
2690     assert(n <= 15 * sizeof(float));
2691     // Prepare mask for valid 32-bit elements (depends on n).
2692     n >>= 2 /* log2(sizeof(float)) */;
2693     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2694 
2695     __m512 vacc = _mm512_maskz_loadu_ps(vmask, x);
2696     const __mmask16 vsign = _mm512_mask_cmp_ps_mask(vmask, vacc, vzero, _CMP_LT_OQ);
2697     vacc = _mm512_mask_mul_ps(vacc, vsign, vacc, vslope);
2698     _mm512_mask_storeu_ps(y, vmask, vacc);
2699   }
2700 }
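
// Editorial note: the merge-masked multiply above implements LeakyReLU; only
// lanes whose input compares below zero are scaled by the slope, all other
// lanes pass through unchanged.  Illustrative scalar sketch (not part of
// XNNPACK; the xnn_editorial_ name is hypothetical):
static inline float xnn_editorial_lrelu_ref(float x, float slope) {
  return x < 0.0f ? x * slope : x;
}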
2701 
2702 void xnn_f32_vrndd_ukernel__avx512f_x16(
2703     size_t n,
2704     const float* x,
2705     float* y,
2706     const union xnn_f32_rnd_params params[restrict XNN_MIN_ELEMENTS(1)])
2707 {
2708   assert(n != 0);
2709   assert(n % sizeof(float) == 0);
2710 
2711   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2712     const __m512 vx0123456789ABCDEF = _mm512_loadu_ps(x);
2713     x += 16;
2714 
2715     const __m512 vy0123456789ABCDEF = _mm512_roundscale_ps(vx0123456789ABCDEF, _MM_FROUND_TO_NEG_INF);
2716 
2717     _mm512_storeu_ps(y, vy0123456789ABCDEF);
2718     y += 16;
2719   }
2720   if XNN_UNLIKELY(n != 0) {
2721     assert(n >= 1 * sizeof(float));
2722     assert(n <= 15 * sizeof(float));
2723     // Prepare mask for valid 32-bit elements (depends on n).
2724     n >>= 2 /* log2(sizeof(float)) */;
2725     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2726 
2727     const __m512 vx = _mm512_maskz_loadu_ps(vmask, x);
2728     const __m512 vy = _mm512_maskz_roundscale_ps(vmask, vx, _MM_FROUND_TO_NEG_INF);
2729     _mm512_mask_storeu_ps(y, vmask, vy);
2730   }
2731 }
2732 
2733 void xnn_f32_vrndne_ukernel__avx512f_x16(
2734     size_t n,
2735     const float* x,
2736     float* y,
2737     const union xnn_f32_rnd_params params[restrict XNN_MIN_ELEMENTS(1)])
2738 {
2739   assert(n != 0);
2740   assert(n % sizeof(float) == 0);
2741 
2742   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2743     const __m512 vx0123456789ABCDEF = _mm512_loadu_ps(x);
2744     x += 16;
2745 
2746     const __m512 vy0123456789ABCDEF = _mm512_roundscale_ps(vx0123456789ABCDEF, _MM_FROUND_TO_NEAREST_INT);
2747 
2748     _mm512_storeu_ps(y, vy0123456789ABCDEF);
2749     y += 16;
2750   }
2751   if XNN_UNLIKELY(n != 0) {
2752     assert(n >= 1 * sizeof(float));
2753     assert(n <= 15 * sizeof(float));
2754     // Prepare mask for valid 32-bit elements (depends on n).
2755     n >>= 2 /* log2(sizeof(float)) */;
2756     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2757 
2758     const __m512 vx = _mm512_maskz_loadu_ps(vmask, x);
2759     const __m512 vy = _mm512_maskz_roundscale_ps(vmask, vx, _MM_FROUND_TO_NEAREST_INT);
2760     _mm512_mask_storeu_ps(y, vmask, vy);
2761   }
2762 }
2763 
2764 void xnn_f32_vrndu_ukernel__avx512f_x16(
2765     size_t n,
2766     const float* x,
2767     float* y,
2768     const union xnn_f32_rnd_params params[restrict XNN_MIN_ELEMENTS(1)])
2769 {
2770   assert(n != 0);
2771   assert(n % sizeof(float) == 0);
2772 
2773   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2774     const __m512 vx0123456789ABCDEF = _mm512_loadu_ps(x);
2775     x += 16;
2776 
2777     const __m512 vy0123456789ABCDEF = _mm512_roundscale_ps(vx0123456789ABCDEF, _MM_FROUND_TO_POS_INF);
2778 
2779     _mm512_storeu_ps(y, vy0123456789ABCDEF);
2780     y += 16;
2781   }
2782   if XNN_UNLIKELY(n != 0) {
2783     assert(n >= 1 * sizeof(float));
2784     assert(n <= 15 * sizeof(float));
2785     // Prepare mask for valid 32-bit elements (depends on n).
2786     n >>= 2 /* log2(sizeof(float)) */;
2787     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2788 
2789     const __m512 vx = _mm512_maskz_loadu_ps(vmask, x);
2790     const __m512 vy = _mm512_maskz_roundscale_ps(vmask, vx, _MM_FROUND_TO_POS_INF);
2791     _mm512_mask_storeu_ps(y, vmask, vy);
2792   }
2793 }
2794 
2795 void xnn_f32_vrndz_ukernel__avx512f_x16(
2796     size_t n,
2797     const float* x,
2798     float* y,
2799     const union xnn_f32_rnd_params params[restrict XNN_MIN_ELEMENTS(1)])
2800 {
2801   assert(n != 0);
2802   assert(n % sizeof(float) == 0);
2803 
2804   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2805     const __m512 vx0123456789ABCDEF = _mm512_loadu_ps(x);
2806     x += 16;
2807 
2808     const __m512 vy0123456789ABCDEF = _mm512_roundscale_ps(vx0123456789ABCDEF, _MM_FROUND_TO_ZERO);
2809 
2810     _mm512_storeu_ps(y, vy0123456789ABCDEF);
2811     y += 16;
2812   }
2813   if XNN_UNLIKELY(n != 0) {
2814     assert(n >= 1 * sizeof(float));
2815     assert(n <= 15 * sizeof(float));
2816     // Prepare mask for valid 32-bit elements (depends on n).
2817     n >>= 2 /* log2(sizeof(float)) */;
2818     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2819 
2820     const __m512 vx = _mm512_maskz_loadu_ps(vmask, x);
2821     const __m512 vy = _mm512_maskz_roundscale_ps(vmask, vx, _MM_FROUND_TO_ZERO);
2822     _mm512_mask_storeu_ps(y, vmask, vy);
2823   }
2824 }
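
// Editorial note: the four rounding kernels above are identical except for the
// immediate passed to _mm512_roundscale_ps:
//   _MM_FROUND_TO_NEG_INF     - round down (floor, vrndd)
//   _MM_FROUND_TO_NEAREST_INT - round to nearest, ties to even (vrndne)
//   _MM_FROUND_TO_POS_INF     - round up (ceil, vrndu)
//   _MM_FROUND_TO_ZERO        - truncate toward zero (vrndz)
// In each remainder path the zero-masking _mm512_maskz_roundscale_ps and the
// masked load/store keep lanes beyond n from being read or written.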
2825 
2826 void xnn_f32_vsigmoid_ukernel__avx512f_rr2_lut32_p2_perm2_scalef_div_x64(
2827     size_t n,
2828     const float* x,
2829     float* y,
2830     const union xnn_f32_sigmoid_params params[restrict XNN_MIN_ELEMENTS(1)])
2831 {
2832   assert(n % sizeof(float) == 0);
2833 
2834   const __m512i vsign_mask = _mm512_set1_epi32((int) params->avx512_rr2_lut32_p2.sign_mask);
2835   const __m512 vmagic_bias = _mm512_set1_ps(params->avx512_rr2_lut32_p2.magic_bias);
2836   const __m512 vlog2e = _mm512_set1_ps(params->avx512_rr2_lut32_p2.log2e);
2837   const __m512 vtable_lo = _mm512_load_ps(params->avx512_rr2_lut32_p2.table_lo);
2838   const __m512 vtable_hi = _mm512_load_ps(params->avx512_rr2_lut32_p2.table_hi);
2839   const __m512 vminus_ln2_hi = _mm512_set1_ps(params->avx512_rr2_lut32_p2.minus_ln2_hi);
2840   const __m512 vminus_ln2_lo = _mm512_set1_ps(params->avx512_rr2_lut32_p2.minus_ln2_lo);
2841   const __m512 vc2 = _mm512_set1_ps(params->avx512_rr2_lut32_p2.c2);
2842   const __m512 vc1 = _mm512_set1_ps(params->avx512_rr2_lut32_p2.c1);
2843   const __m512 vone = _mm512_set1_ps(params->avx512_rr2_lut32_p2.one);
2844 
2845   for (; n >= 64 * sizeof(float); n -= 64 * sizeof(float)) {
2846     const __m512 vx0 = _mm512_loadu_ps(x);
2847     const __m512 vx1 = _mm512_loadu_ps(x + 16);
2848     const __m512 vx2 = _mm512_loadu_ps(x + 32);
2849     const __m512 vx3 = _mm512_loadu_ps(x + 48);
2850     x += 64;
2851 
2852     const __m512 vz0 = _mm512_castsi512_ps(_mm512_or_epi32(_mm512_castps_si512(vx0), vsign_mask));
2853     const __m512 vz1 = _mm512_castsi512_ps(_mm512_or_epi32(_mm512_castps_si512(vx1), vsign_mask));
2854     const __m512 vz2 = _mm512_castsi512_ps(_mm512_or_epi32(_mm512_castps_si512(vx2), vsign_mask));
2855     const __m512 vz3 = _mm512_castsi512_ps(_mm512_or_epi32(_mm512_castps_si512(vx3), vsign_mask));
2856 
2857     __m512 vn0 = _mm512_fmadd_ps(vz0, vlog2e, vmagic_bias);
2858     __m512 vn1 = _mm512_fmadd_ps(vz1, vlog2e, vmagic_bias);
2859     __m512 vn2 = _mm512_fmadd_ps(vz2, vlog2e, vmagic_bias);
2860     __m512 vn3 = _mm512_fmadd_ps(vz3, vlog2e, vmagic_bias);
2861 
2862     const __m512 vl0 = _mm512_permutex2var_ps(vtable_lo, _mm512_castps_si512(vn0), vtable_hi);
2863     const __m512 vl1 = _mm512_permutex2var_ps(vtable_lo, _mm512_castps_si512(vn1), vtable_hi);
2864     const __m512 vl2 = _mm512_permutex2var_ps(vtable_lo, _mm512_castps_si512(vn2), vtable_hi);
2865     const __m512 vl3 = _mm512_permutex2var_ps(vtable_lo, _mm512_castps_si512(vn3), vtable_hi);
2866 
2867     vn0 = _mm512_sub_ps(vn0, vmagic_bias);
2868     vn1 = _mm512_sub_ps(vn1, vmagic_bias);
2869     vn2 = _mm512_sub_ps(vn2, vmagic_bias);
2870     vn3 = _mm512_sub_ps(vn3, vmagic_bias);
2871 
2872     __m512 vt0 = _mm512_fmadd_ps(vn0, vminus_ln2_hi, vz0);
2873     __m512 vt1 = _mm512_fmadd_ps(vn1, vminus_ln2_hi, vz1);
2874     __m512 vt2 = _mm512_fmadd_ps(vn2, vminus_ln2_hi, vz2);
2875     __m512 vt3 = _mm512_fmadd_ps(vn3, vminus_ln2_hi, vz3);
2876 
2877     vt0 = _mm512_fmadd_ps(vn0, vminus_ln2_lo, vt0);
2878     vt1 = _mm512_fmadd_ps(vn1, vminus_ln2_lo, vt1);
2879     vt2 = _mm512_fmadd_ps(vn2, vminus_ln2_lo, vt2);
2880     vt3 = _mm512_fmadd_ps(vn3, vminus_ln2_lo, vt3);
2881 
2882     __m512 vp0 = _mm512_fmadd_ps(vt0, vc2, vc1);
2883     __m512 vp1 = _mm512_fmadd_ps(vt1, vc2, vc1);
2884     __m512 vp2 = _mm512_fmadd_ps(vt2, vc2, vc1);
2885     __m512 vp3 = _mm512_fmadd_ps(vt3, vc2, vc1);
2886 
2887     vt0 = _mm512_mul_ps(vt0, vl0);
2888     vt1 = _mm512_mul_ps(vt1, vl1);
2889     vt2 = _mm512_mul_ps(vt2, vl2);
2890     vt3 = _mm512_mul_ps(vt3, vl3);
2891 
2892     vp0 = _mm512_fmadd_ps(vt0, vp0, vl0);
2893     vp1 = _mm512_fmadd_ps(vt1, vp1, vl1);
2894     vp2 = _mm512_fmadd_ps(vt2, vp2, vl2);
2895     vp3 = _mm512_fmadd_ps(vt3, vp3, vl3);
2896 
2897     const __m512 ve0 = _mm512_scalef_ps(vp0, vn0);
2898     const __m512 ve1 = _mm512_scalef_ps(vp1, vn1);
2899     const __m512 ve2 = _mm512_scalef_ps(vp2, vn2);
2900     const __m512 ve3 = _mm512_scalef_ps(vp3, vn3);
2901 
2902     const __m512 vd0 = _mm512_add_ps(ve0, vone);
2903     const __m512 vd1 = _mm512_add_ps(ve1, vone);
2904     const __m512 vd2 = _mm512_add_ps(ve2, vone);
2905     const __m512 vd3 = _mm512_add_ps(ve3, vone);
2906 
2907     __m512 vf0 = _mm512_div_ps(ve0, vd0);
2908     __m512 vf1 = _mm512_div_ps(ve1, vd1);
2909     __m512 vf2 = _mm512_div_ps(ve2, vd2);
2910     __m512 vf3 = _mm512_div_ps(ve3, vd3);
2911 
2912     vf0 = _mm512_mask_sub_ps(vf0, _mm512_testn_epi32_mask(_mm512_castps_si512(vx0), vsign_mask), vone, vf0);
2913     vf1 = _mm512_mask_sub_ps(vf1, _mm512_testn_epi32_mask(_mm512_castps_si512(vx1), vsign_mask), vone, vf1);
2914     vf2 = _mm512_mask_sub_ps(vf2, _mm512_testn_epi32_mask(_mm512_castps_si512(vx2), vsign_mask), vone, vf2);
2915     vf3 = _mm512_mask_sub_ps(vf3, _mm512_testn_epi32_mask(_mm512_castps_si512(vx3), vsign_mask), vone, vf3);
2916 
2917     _mm512_storeu_ps(y, vf0);
2918     _mm512_storeu_ps(y + 16, vf1);
2919     _mm512_storeu_ps(y + 32, vf2);
2920     _mm512_storeu_ps(y + 48, vf3);
2921     y += 64;
2922   }
2923   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2924     const __m512 vx = _mm512_loadu_ps(x);
2925     x += 16;
2926 
2927     const __m512 vz = _mm512_castsi512_ps(_mm512_or_epi32(_mm512_castps_si512(vx), vsign_mask));
2928 
2929     __m512 vn = _mm512_fmadd_ps(vz, vlog2e, vmagic_bias);
2930     const __m512 vl = _mm512_permutex2var_ps(vtable_lo, _mm512_castps_si512(vn), vtable_hi);
2931     vn = _mm512_sub_ps(vn, vmagic_bias);
2932 
2933     __m512 vt = _mm512_fmadd_ps(vn, vminus_ln2_hi, vz);
2934     vt = _mm512_fmadd_ps(vn, vminus_ln2_lo, vt);
2935 
2936     __m512 vp = _mm512_fmadd_ps(vt, vc2, vc1);
2937     vt = _mm512_mul_ps(vt, vl);
2938     vp = _mm512_fmadd_ps(vt, vp, vl);
2939 
2940     const __m512 ve = _mm512_scalef_ps(vp, vn);
2941     const __m512 vd = _mm512_add_ps(ve, vone);
2942 
2943     __m512 vf = _mm512_div_ps(ve, vd);
2944 
2945     vf = _mm512_mask_sub_ps(vf, _mm512_testn_epi32_mask(_mm512_castps_si512(vx), vsign_mask), vone, vf);
2946 
2947     _mm512_storeu_ps(y, vf);
2948     y += 16;
2949   }
2950   if XNN_UNLIKELY(n != 0) {
2951     assert(n >= 1 * sizeof(float));
2952     assert(n <= 15 * sizeof(float));
2953 
2954     // Prepare mask for valid 32-bit elements (depends on n).
2955     n >>= 2 /* log2(sizeof(float)) */;
2956     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2957 
2958     const __m512 vx = _mm512_maskz_loadu_ps(vmask, x);
2959     const __m512 vz = _mm512_castsi512_ps(_mm512_or_epi32(_mm512_castps_si512(vx), vsign_mask));
2960 
2961     __m512 vn = _mm512_fmadd_ps(vz, vlog2e, vmagic_bias);
2962     const __m512 vl = _mm512_permutex2var_ps(vtable_lo, _mm512_castps_si512(vn), vtable_hi);
2963     vn = _mm512_sub_ps(vn, vmagic_bias);
2964 
2965     __m512 vt = _mm512_fmadd_ps(vn, vminus_ln2_hi, vz);
2966     vt = _mm512_fmadd_ps(vn, vminus_ln2_lo, vt);
2967 
2968     __m512 vp = _mm512_fmadd_ps(vt, vc2, vc1);
2969     vt = _mm512_mul_ps(vt, vl);
2970     vp = _mm512_fmadd_ps(vt, vp, vl);
2971 
2972     const __m512 ve = _mm512_scalef_ps(vp, vn);
2973     const __m512 vd = _mm512_add_ps(ve, vone);
2974 
2975     __m512 vf = _mm512_div_ps(ve, vd);
2976 
2977     vf = _mm512_mask_sub_ps(vf, _mm512_testn_epi32_mask(_mm512_castps_si512(vx), vsign_mask), vone, vf);
2978 
2979     _mm512_mask_storeu_ps(y, vmask, vf);
2980   }
2981 }
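
// Editorial note: a rough outline of the approximation used by the sigmoid
// kernel above (descriptive only; details inferred from the code):
//   z = -|x|                        (vsign_mask forces the sign bit on)
//   n = z * log2(e), quantized with the magic-bias trick; while still biased,
//       the low 5 bits of n select one of the 32 table entries l ~ 2^(k/32)
//       held across vtable_lo/vtable_hi (_mm512_permutex2var_ps)
//   t = z - n * ln(2)               (ln(2) applied in hi and lo halves, "rr2")
//   e = 2^floor(n) * l * (1 + c1*t + c2*t^2)   (_mm512_scalef_ps supplies 2^floor(n))
//   sigmoid(x) = e / (e + 1) for x <= 0, and 1 - e / (e + 1) for x > 0
//   (the final _mm512_mask_sub_ps flips lanes whose sign bit was clear).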
2982 
2983 void xnn_f32_vabs_ukernel__avx512f_x16(
2984     size_t n,
2985     const float* x,
2986     float* y,
2987     const union xnn_f32_abs_params params[restrict XNN_MIN_ELEMENTS(1)])
2988 {
2989   assert(n != 0);
2990   assert(n % sizeof(float) == 0);
2991   assert(x != NULL);
2992   assert(y != NULL);
2993 
2994   const __m512i vnonsign_mask = _mm512_set1_epi32((int) params->avx512.nonsign_mask);
2995   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2996     const __m512i vx0123456789ABCDEF = _mm512_loadu_si512(x);
2997     x += 16;
2998 
2999     const __m512i vy0123456789ABCDEF = _mm512_and_epi32(vx0123456789ABCDEF, vnonsign_mask);
3000 
3001     _mm512_storeu_si512(y, vy0123456789ABCDEF);
3002     y += 16;
3003   }
3004   if XNN_UNLIKELY(n != 0) {
3005     assert(n >= 1 * sizeof(float));
3006     assert(n <= 15 * sizeof(float));
3007     // Prepare mask for valid 32-bit elements (depends on n).
3008     n >>= 2 /* log2(sizeof(float)) */;
3009     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
3010 
3011     const __m512i vx = _mm512_maskz_loadu_epi32(vmask, x);
3012     const __m512i vy = _mm512_and_epi32(vx, vnonsign_mask);
3013     _mm512_mask_storeu_epi32(y, vmask, vy);
3014   }
3015 }
3016 
3017 void xnn_f32_vneg_ukernel__avx512f_x16(
3018     size_t n,
3019     const float* x,
3020     float* y,
3021     const union xnn_f32_neg_params params[restrict XNN_MIN_ELEMENTS(1)])
3022 {
3023   assert(n != 0);
3024   assert(n % sizeof(float) == 0);
3025   assert(x != NULL);
3026   assert(y != NULL);
3027 
3028   const __m512i vsign_mask = _mm512_set1_epi32((int) params->avx512.sign_mask);
3029   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
3030     const __m512i vx0123456789ABCDEF = _mm512_loadu_si512(x);
3031     x += 16;
3032 
3033     const __m512i vy0123456789ABCDEF = _mm512_xor_epi32(vx0123456789ABCDEF, vsign_mask);
3034 
3035     _mm512_storeu_si512(y, vy0123456789ABCDEF);
3036     y += 16;
3037   }
3038   if XNN_UNLIKELY(n != 0) {
3039     assert(n >= 1 * sizeof(float));
3040     assert(n <= 15 * sizeof(float));
3041     // Prepare mask for valid 32-bit elements (depends on n).
3042     n >>= 2 /* log2(sizeof(float)) */;
3043     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
3044 
3045     const __m512i vx = _mm512_maskz_loadu_epi32(vmask, x);
3046     const __m512i vy = _mm512_xor_epi32(vx, vsign_mask);
3047     _mm512_mask_storeu_epi32(y, vmask, vy);
3048   }
3049 }
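
// Editorial note: the two kernels above operate directly on the IEEE-754 bit
// pattern: vabs clears the sign bit with nonsign_mask and vneg flips it with
// sign_mask.  Illustrative scalar sketch (not part of XNNPACK; it assumes the
// masks are 0x7FFFFFFF and 0x80000000, as set up by the f32 abs/neg params,
// and the xnn_editorial_ names are hypothetical):
static inline float xnn_editorial_abs_ref(float x) {
  union { float f; uint32_t u; } bits = { .f = x };
  bits.u &= UINT32_C(0x7FFFFFFF);  // clear the sign bit
  return bits.f;
}
static inline float xnn_editorial_neg_ref(float x) {
  union { float f; uint32_t u; } bits = { .f = x };
  bits.u ^= UINT32_C(0x80000000);  // flip the sign bit
  return bits.f;
}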
3050 
3051 void xnn_f32_vsqr_ukernel__avx512f_x16(
3052     size_t n,
3053     const float* x,
3054     float* y,
3055     const union xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
3056 {
3057   assert(n != 0);
3058   assert(n % sizeof(float) == 0);
3059   assert(x != NULL);
3060   assert(y != NULL);
3061 
3062   for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
3063     const __m512 vx0123456789ABCDEF = _mm512_loadu_ps(x);
3064     x += 16;
3065 
3066     const __m512 vy0123456789ABCDEF = _mm512_mul_ps(vx0123456789ABCDEF, vx0123456789ABCDEF);
3067 
3068     _mm512_storeu_ps(y, vy0123456789ABCDEF);
3069     y += 16;
3070   }
3071   if XNN_UNLIKELY(n != 0) {
3072     assert(n >= 1 * sizeof(float));
3073     assert(n <= 15 * sizeof(float));
3074     // Prepare mask for valid 32-bit elements (depends on n).
3075     n >>= 2 /* log2(sizeof(float)) */;
3076     const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
3077 
3078     const __m512 vx = _mm512_maskz_loadu_ps(vmask, x);
3079     const __m512 vy = _mm512_mul_ps(vx, vx);
3080     _mm512_mask_storeu_ps(y, vmask, vy);
3081   }
3082 }
3083