1 // Copyright 2021 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7
8 #include <immintrin.h>
9
10 #include <xnnpack/common.h>
11 #include <xnnpack/dwconv.h>
12 #include <xnnpack/gemm.h>
13 #include <xnnpack/igemm.h>
14 #include <xnnpack/intrinsics-polyfill.h>
15 #include <xnnpack/math.h>
16 #include <xnnpack/prelu.h>
17 #include <xnnpack/vbinary.h>
18 #include <xnnpack/vunary.h>
19
20
21 void xnn_f32_dwconv_minmax_ukernel_up16x25__avx512f(
22 size_t channels,
23 size_t output_width,
24 const float** input,
25 const float* weights,
26 float* output,
27 size_t input_stride,
28 size_t output_increment,
29 size_t input_offset,
30 const float* zero,
31 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
32 {
33 assert(channels != 0);
34 assert(output_width != 0);
35
36 const __m512 vmax = _mm512_set1_ps(params->scalar.max);
37 const __m512 vmin = _mm512_set1_ps(params->scalar.min);
38 do {
39 const float* i0 = input[0];
40 assert(i0 != NULL);
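    // Note: pointers equal to `zero` reference a shared zero-filled buffer (padding rows) and are intentionally not shifted by input_offset.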
41 if XNN_UNPREDICTABLE(i0 != zero) {
42 i0 = (const float*) ((uintptr_t) i0 + input_offset);
43 }
44 const float* i1 = input[1];
45 assert(i1 != NULL);
46 if XNN_UNPREDICTABLE(i1 != zero) {
47 i1 = (const float*) ((uintptr_t) i1 + input_offset);
48 }
49 const float* i2 = input[2];
50 assert(i2 != NULL);
51 if XNN_UNPREDICTABLE(i2 != zero) {
52 i2 = (const float*) ((uintptr_t) i2 + input_offset);
53 }
54 const float* i3 = input[3];
55 assert(i3 != NULL);
56 if XNN_UNPREDICTABLE(i3 != zero) {
57 i3 = (const float*) ((uintptr_t) i3 + input_offset);
58 }
59 const float* i4 = input[4];
60 assert(i4 != NULL);
61 if XNN_UNPREDICTABLE(i4 != zero) {
62 i4 = (const float*) ((uintptr_t) i4 + input_offset);
63 }
64 const float* i5 = input[5];
65 assert(i5 != NULL);
66 if XNN_UNPREDICTABLE(i5 != zero) {
67 i5 = (const float*) ((uintptr_t) i5 + input_offset);
68 }
69 const float* i6 = input[6];
70 assert(i6 != NULL);
71 if XNN_UNPREDICTABLE(i6 != zero) {
72 i6 = (const float*) ((uintptr_t) i6 + input_offset);
73 }
74 const float* i7 = input[7];
75 assert(i7 != NULL);
76 if XNN_UNPREDICTABLE(i7 != zero) {
77 i7 = (const float*) ((uintptr_t) i7 + input_offset);
78 }
79 const float* i8 = input[8];
80 assert(i8 != NULL);
81 if XNN_UNPREDICTABLE(i8 != zero) {
82 i8 = (const float*) ((uintptr_t) i8 + input_offset);
83 }
84 const float* i9 = input[9];
85 assert(i9 != NULL);
86 if XNN_UNPREDICTABLE(i9 != zero) {
87 i9 = (const float*) ((uintptr_t) i9 + input_offset);
88 }
89 const float* i10 = input[10];
90 assert(i10 != NULL);
91 if XNN_UNPREDICTABLE(i10 != zero) {
92 i10 = (const float*) ((uintptr_t) i10 + input_offset);
93 }
94 const float* i11 = input[11];
95 assert(i11 != NULL);
96 if XNN_UNPREDICTABLE(i11 != zero) {
97 i11 = (const float*) ((uintptr_t) i11 + input_offset);
98 }
99 const float* i12 = input[12];
100 assert(i12 != NULL);
101 if XNN_UNPREDICTABLE(i12 != zero) {
102 i12 = (const float*) ((uintptr_t) i12 + input_offset);
103 }
104 const float* i13 = input[13];
105 assert(i13 != NULL);
106 if XNN_UNPREDICTABLE(i13 != zero) {
107 i13 = (const float*) ((uintptr_t) i13 + input_offset);
108 }
109 const float* i14 = input[14];
110 assert(i14 != NULL);
111 if XNN_UNPREDICTABLE(i14 != zero) {
112 i14 = (const float*) ((uintptr_t) i14 + input_offset);
113 }
114 const float* i15 = input[15];
115 assert(i15 != NULL);
116 if XNN_UNPREDICTABLE(i15 != zero) {
117 i15 = (const float*) ((uintptr_t) i15 + input_offset);
118 }
119 const float* i16 = input[16];
120 assert(i16 != NULL);
121 if XNN_UNPREDICTABLE(i16 != zero) {
122 i16 = (const float*) ((uintptr_t) i16 + input_offset);
123 }
124 const float* i17 = input[17];
125 assert(i17 != NULL);
126 if XNN_UNPREDICTABLE(i17 != zero) {
127 i17 = (const float*) ((uintptr_t) i17 + input_offset);
128 }
129 const float* i18 = input[18];
130 assert(i18 != NULL);
131 if XNN_UNPREDICTABLE(i18 != zero) {
132 i18 = (const float*) ((uintptr_t) i18 + input_offset);
133 }
134 const float* i19 = input[19];
135 assert(i19 != NULL);
136 if XNN_UNPREDICTABLE(i19 != zero) {
137 i19 = (const float*) ((uintptr_t) i19 + input_offset);
138 }
139 const float* i20 = input[20];
140 assert(i20 != NULL);
141 if XNN_UNPREDICTABLE(i20 != zero) {
142 i20 = (const float*) ((uintptr_t) i20 + input_offset);
143 }
144 const float* i21 = input[21];
145 assert(i21 != NULL);
146 if XNN_UNPREDICTABLE(i21 != zero) {
147 i21 = (const float*) ((uintptr_t) i21 + input_offset);
148 }
149 const float* i22 = input[22];
150 assert(i22 != NULL);
151 if XNN_UNPREDICTABLE(i22 != zero) {
152 i22 = (const float*) ((uintptr_t) i22 + input_offset);
153 }
154 const float* i23 = input[23];
155 assert(i23 != NULL);
156 if XNN_UNPREDICTABLE(i23 != zero) {
157 i23 = (const float*) ((uintptr_t) i23 + input_offset);
158 }
159 const float* i24 = input[24];
160 assert(i24 != NULL);
161 if XNN_UNPREDICTABLE(i24 != zero) {
162 i24 = (const float*) ((uintptr_t) i24 + input_offset);
163 }
164 input = (const float**) ((uintptr_t) input + input_stride);
165
166 size_t c = channels;
167 const float* w = weights;
168 for (; c >= 16; c -= 16) {
169 __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
170
171
172 const __m512 vi0x0123456789ABCDEF = _mm512_loadu_ps(i0);
173 i0 += 16;
174
175 const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 16);
176 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
177
178 const __m512 vi1x0123456789ABCDEF = _mm512_loadu_ps(i1);
179 i1 += 16;
180
181 const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 32);
182 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
183
184 const __m512 vi2x0123456789ABCDEF = _mm512_loadu_ps(i2);
185 i2 += 16;
186
187 const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 48);
188 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
189
190 const __m512 vi3x0123456789ABCDEF = _mm512_loadu_ps(i3);
191 i3 += 16;
192
193 const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 64);
194 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
195
196 const __m512 vi4x0123456789ABCDEF = _mm512_loadu_ps(i4);
197 i4 += 16;
198
199 const __m512 vk4x0123456789ABCDEF = _mm512_load_ps(w + 80);
200 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
201
202 const __m512 vi5x0123456789ABCDEF = _mm512_loadu_ps(i5);
203 i5 += 16;
204
205 const __m512 vk5x0123456789ABCDEF = _mm512_load_ps(w + 96);
206 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp0);
207
208 const __m512 vi6x0123456789ABCDEF = _mm512_loadu_ps(i6);
209 i6 += 16;
210
211 const __m512 vk6x0123456789ABCDEF = _mm512_load_ps(w + 112);
212 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
213
214 const __m512 vi7x0123456789ABCDEF = _mm512_loadu_ps(i7);
215 i7 += 16;
216
217 const __m512 vk7x0123456789ABCDEF = _mm512_load_ps(w + 128);
218 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp0);
219
220 const __m512 vi8x0123456789ABCDEF = _mm512_loadu_ps(i8);
221 i8 += 16;
222
223 const __m512 vk8x0123456789ABCDEF = _mm512_load_ps(w + 144);
224 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
225
226 const __m512 vi9x0123456789ABCDEF = _mm512_loadu_ps(i9);
227 i9 += 16;
228
229 const __m512 vk9x0123456789ABCDEF = _mm512_load_ps(w + 160);
230 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF, vacc0123456789ABCDEFp0);
231
232 const __m512 vi10x0123456789ABCDEF = _mm512_loadu_ps(i10);
233 i10 += 16;
234
235 const __m512 vk10x0123456789ABCDEF = _mm512_load_ps(w + 176);
236 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF, vacc0123456789ABCDEFp0);
237
238 const __m512 vi11x0123456789ABCDEF = _mm512_loadu_ps(i11);
239 i11 += 16;
240
241 const __m512 vk11x0123456789ABCDEF = _mm512_load_ps(w + 192);
242 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF, vacc0123456789ABCDEFp0);
243
244 const __m512 vi12x0123456789ABCDEF = _mm512_loadu_ps(i12);
245 i12 += 16;
246
247 const __m512 vk12x0123456789ABCDEF = _mm512_load_ps(w + 208);
248 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF, vacc0123456789ABCDEFp0);
249
250 const __m512 vi13x0123456789ABCDEF = _mm512_loadu_ps(i13);
251 i13 += 16;
252
253 const __m512 vk13x0123456789ABCDEF = _mm512_load_ps(w + 224);
254 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF, vacc0123456789ABCDEFp0);
255
256 const __m512 vi14x0123456789ABCDEF = _mm512_loadu_ps(i14);
257 i14 += 16;
258
259 const __m512 vk14x0123456789ABCDEF = _mm512_load_ps(w + 240);
260 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF, vacc0123456789ABCDEFp0);
261
262 const __m512 vi15x0123456789ABCDEF = _mm512_loadu_ps(i15);
263 i15 += 16;
264
265 const __m512 vk15x0123456789ABCDEF = _mm512_load_ps(w + 256);
266 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF, vacc0123456789ABCDEFp0);
267
268 const __m512 vi16x0123456789ABCDEF = _mm512_loadu_ps(i16);
269 i16 += 16;
270
271 const __m512 vk16x0123456789ABCDEF = _mm512_load_ps(w + 272);
272 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF, vacc0123456789ABCDEFp0);
273
274 const __m512 vi17x0123456789ABCDEF = _mm512_loadu_ps(i17);
275 i17 += 16;
276
277 const __m512 vk17x0123456789ABCDEF = _mm512_load_ps(w + 288);
278 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF, vacc0123456789ABCDEFp0);
279
280 const __m512 vi18x0123456789ABCDEF = _mm512_loadu_ps(i18);
281 i18 += 16;
282
283 const __m512 vk18x0123456789ABCDEF = _mm512_load_ps(w + 304);
284 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF, vacc0123456789ABCDEFp0);
285
286 const __m512 vi19x0123456789ABCDEF = _mm512_loadu_ps(i19);
287 i19 += 16;
288
289 const __m512 vk19x0123456789ABCDEF = _mm512_load_ps(w + 320);
290 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF, vacc0123456789ABCDEFp0);
291
292 const __m512 vi20x0123456789ABCDEF = _mm512_loadu_ps(i20);
293 i20 += 16;
294
295 const __m512 vk20x0123456789ABCDEF = _mm512_load_ps(w + 336);
296 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF, vacc0123456789ABCDEFp0);
297
298 const __m512 vi21x0123456789ABCDEF = _mm512_loadu_ps(i21);
299 i21 += 16;
300
301 const __m512 vk21x0123456789ABCDEF = _mm512_load_ps(w + 352);
302 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF, vacc0123456789ABCDEFp0);
303
304 const __m512 vi22x0123456789ABCDEF = _mm512_loadu_ps(i22);
305 i22 += 16;
306
307 const __m512 vk22x0123456789ABCDEF = _mm512_load_ps(w + 368);
308 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF, vacc0123456789ABCDEFp0);
309
310 const __m512 vi23x0123456789ABCDEF = _mm512_loadu_ps(i23);
311 i23 += 16;
312
313 const __m512 vk23x0123456789ABCDEF = _mm512_load_ps(w + 384);
314 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF, vacc0123456789ABCDEFp0);
315
316 const __m512 vi24x0123456789ABCDEF = _mm512_loadu_ps(i24);
317 i24 += 16;
318
319 const __m512 vk24x0123456789ABCDEF = _mm512_load_ps(w + 400);
320 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF, vacc0123456789ABCDEFp0);
321
322 w += 416;
323
324
325 __m512 vacc0123456789ABCDEF = _mm512_max_ps(vacc0123456789ABCDEFp0, vmin);
326 vacc0123456789ABCDEF = _mm512_min_ps(vacc0123456789ABCDEF, vmax);
327
328 _mm512_storeu_ps(output, vacc0123456789ABCDEF);
329 output += 16;
330 }
331 if XNN_UNLIKELY(c != 0) {
332 assert(c >= 1);
333 assert(c <= 16);
334       // Prepare mask for valid 32-bit elements (depends on c).
335 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
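      // For example, c == 3 gives vmask == 0b0000000000000111, so only the 3 lowest float lanes take part in the masked loads and the masked store below.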
336
337 __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
338
339 const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
340 const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 16);
341 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
342
343 const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
344 const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
345 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
346
347 const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
348 const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 48);
349 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
350
351 const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
352 const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
353 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
354
355 const __m512 vi4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i4);
356 const __m512 vk4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 80);
357 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
358
359 const __m512 vi5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i5);
360 const __m512 vk5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96);
361 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp0);
362
363 const __m512 vi6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i6);
364 const __m512 vk6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 112);
365 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
366
367 const __m512 vi7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i7);
368 const __m512 vk7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128);
369 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp0);
370
371 const __m512 vi8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i8);
372 const __m512 vk8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 144);
373 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
374
375 const __m512 vi9x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i9);
376 const __m512 vk9x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 160);
377 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF, vacc0123456789ABCDEFp0);
378
379 const __m512 vi10x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i10);
380 const __m512 vk10x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 176);
381 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF, vacc0123456789ABCDEFp0);
382
383 const __m512 vi11x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i11);
384 const __m512 vk11x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 192);
385 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF, vacc0123456789ABCDEFp0);
386
387 const __m512 vi12x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i12);
388 const __m512 vk12x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 208);
389 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF, vacc0123456789ABCDEFp0);
390
391 const __m512 vi13x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i13);
392 const __m512 vk13x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 224);
393 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF, vacc0123456789ABCDEFp0);
394
395 const __m512 vi14x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i14);
396 const __m512 vk14x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 240);
397 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF, vacc0123456789ABCDEFp0);
398
399 const __m512 vi15x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i15);
400 const __m512 vk15x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 256);
401 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF, vacc0123456789ABCDEFp0);
402
403 const __m512 vi16x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i16);
404 const __m512 vk16x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 272);
405 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF, vacc0123456789ABCDEFp0);
406
407 const __m512 vi17x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i17);
408 const __m512 vk17x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 288);
409 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF, vacc0123456789ABCDEFp0);
410
411 const __m512 vi18x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i18);
412 const __m512 vk18x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 304);
413 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF, vacc0123456789ABCDEFp0);
414
415 const __m512 vi19x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i19);
416 const __m512 vk19x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 320);
417 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF, vacc0123456789ABCDEFp0);
418
419 const __m512 vi20x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i20);
420 const __m512 vk20x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 336);
421 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF, vacc0123456789ABCDEFp0);
422
423 const __m512 vi21x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i21);
424 const __m512 vk21x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 352);
425 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF, vacc0123456789ABCDEFp0);
426
427 const __m512 vi22x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i22);
428 const __m512 vk22x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 368);
429 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF, vacc0123456789ABCDEFp0);
430
431 const __m512 vi23x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i23);
432 const __m512 vk23x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 384);
433 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF, vacc0123456789ABCDEFp0);
434
435 const __m512 vi24x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i24);
436 const __m512 vk24x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 400);
437 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF, vacc0123456789ABCDEFp0);
438
439
440 __m512 vacc0123456789ABCDEF = _mm512_max_ps(vacc0123456789ABCDEFp0, vmin);
441 vacc0123456789ABCDEF = _mm512_min_ps(vacc0123456789ABCDEF, vmax);
442
443 _mm512_mask_storeu_ps(output, vmask, vacc0123456789ABCDEF);
444 output += c;
445 }
446
447 output = (float*) ((uintptr_t) output + output_increment);
448 } while (--output_width != 0);
449 }
450
451 void xnn_f32_dwconv_minmax_ukernel_up16x3__avx512f(
452 size_t channels,
453 size_t output_width,
454 const float** input,
455 const float* weights,
456 float* output,
457 size_t input_stride,
458 size_t output_increment,
459 size_t input_offset,
460 const float* zero,
461 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
462 {
463 assert(channels != 0);
464 assert(output_width != 0);
465
466 const __m512 vmax = _mm512_set1_ps(params->scalar.max);
467 const __m512 vmin = _mm512_set1_ps(params->scalar.min);
468 do {
469 const float* i0 = input[0];
470 assert(i0 != NULL);
471 if XNN_UNPREDICTABLE(i0 != zero) {
472 i0 = (const float*) ((uintptr_t) i0 + input_offset);
473 }
474 const float* i1 = input[1];
475 assert(i1 != NULL);
476 if XNN_UNPREDICTABLE(i1 != zero) {
477 i1 = (const float*) ((uintptr_t) i1 + input_offset);
478 }
479 const float* i2 = input[2];
480 assert(i2 != NULL);
481 if XNN_UNPREDICTABLE(i2 != zero) {
482 i2 = (const float*) ((uintptr_t) i2 + input_offset);
483 }
484 input = (const float**) ((uintptr_t) input + input_stride);
485
486 size_t c = channels;
487 const float* w = weights;
488 for (; c >= 16; c -= 16) {
489 __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
490
491
492 const __m512 vi0x0123456789ABCDEF = _mm512_loadu_ps(i0);
493 i0 += 16;
494
495 const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 16);
496 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
497
498 const __m512 vi1x0123456789ABCDEF = _mm512_loadu_ps(i1);
499 i1 += 16;
500
501 const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 32);
502 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
503
504 const __m512 vi2x0123456789ABCDEF = _mm512_loadu_ps(i2);
505 i2 += 16;
506
507 const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 48);
508 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
509
510 w += 64;
511
512
513 __m512 vacc0123456789ABCDEF = _mm512_max_ps(vacc0123456789ABCDEFp0, vmin);
514 vacc0123456789ABCDEF = _mm512_min_ps(vacc0123456789ABCDEF, vmax);
515
516 _mm512_storeu_ps(output, vacc0123456789ABCDEF);
517 output += 16;
518 }
519 if XNN_UNLIKELY(c != 0) {
520 assert(c >= 1);
521 assert(c <= 16);
522       // Prepare mask for valid 32-bit elements (depends on c).
523 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
524
525 __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
526
527 const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
528 const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 16);
529 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
530
531 const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
532 const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
533 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
534
535 const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
536 const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 48);
537 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
538
539
540 __m512 vacc0123456789ABCDEF = _mm512_max_ps(vacc0123456789ABCDEFp0, vmin);
541 vacc0123456789ABCDEF = _mm512_min_ps(vacc0123456789ABCDEF, vmax);
542
543 _mm512_mask_storeu_ps(output, vmask, vacc0123456789ABCDEF);
544 output += c;
545 }
546
547 output = (float*) ((uintptr_t) output + output_increment);
548 } while (--output_width != 0);
549 }
550
551 void xnn_f32_dwconv_minmax_ukernel_up16x4__avx512f(
552 size_t channels,
553 size_t output_width,
554 const float** input,
555 const float* weights,
556 float* output,
557 size_t input_stride,
558 size_t output_increment,
559 size_t input_offset,
560 const float* zero,
561 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
562 {
563 assert(channels != 0);
564 assert(output_width != 0);
565
566 const __m512 vmax = _mm512_set1_ps(params->scalar.max);
567 const __m512 vmin = _mm512_set1_ps(params->scalar.min);
568 do {
569 const float* i0 = input[0];
570 assert(i0 != NULL);
571 if XNN_UNPREDICTABLE(i0 != zero) {
572 i0 = (const float*) ((uintptr_t) i0 + input_offset);
573 }
574 const float* i1 = input[1];
575 assert(i1 != NULL);
576 if XNN_UNPREDICTABLE(i1 != zero) {
577 i1 = (const float*) ((uintptr_t) i1 + input_offset);
578 }
579 const float* i2 = input[2];
580 assert(i2 != NULL);
581 if XNN_UNPREDICTABLE(i2 != zero) {
582 i2 = (const float*) ((uintptr_t) i2 + input_offset);
583 }
584 const float* i3 = input[3];
585 assert(i3 != NULL);
586 if XNN_UNPREDICTABLE(i3 != zero) {
587 i3 = (const float*) ((uintptr_t) i3 + input_offset);
588 }
589 input = (const float**) ((uintptr_t) input + input_stride);
590
591 size_t c = channels;
592 const float* w = weights;
593 for (; c >= 16; c -= 16) {
594 __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
595
596
597 const __m512 vi0x0123456789ABCDEF = _mm512_loadu_ps(i0);
598 i0 += 16;
599
600 const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 16);
601 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
602
603 const __m512 vi1x0123456789ABCDEF = _mm512_loadu_ps(i1);
604 i1 += 16;
605
606 const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 32);
607 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
608
609 const __m512 vi2x0123456789ABCDEF = _mm512_loadu_ps(i2);
610 i2 += 16;
611
612 const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 48);
613 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
614
615 const __m512 vi3x0123456789ABCDEF = _mm512_loadu_ps(i3);
616 i3 += 16;
617
618 const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 64);
619 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
620
621 w += 80;
622
623
624 __m512 vacc0123456789ABCDEF = _mm512_max_ps(vacc0123456789ABCDEFp0, vmin);
625 vacc0123456789ABCDEF = _mm512_min_ps(vacc0123456789ABCDEF, vmax);
626
627 _mm512_storeu_ps(output, vacc0123456789ABCDEF);
628 output += 16;
629 }
630 if XNN_UNLIKELY(c != 0) {
631 assert(c >= 1);
632 assert(c <= 16);
633       // Prepare mask for valid 32-bit elements (depends on c).
634 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
635
636 __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
637
638 const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
639 const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 16);
640 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
641
642 const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
643 const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
644 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
645
646 const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
647 const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 48);
648 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
649
650 const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
651 const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
652 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
653
654
655 __m512 vacc0123456789ABCDEF = _mm512_max_ps(vacc0123456789ABCDEFp0, vmin);
656 vacc0123456789ABCDEF = _mm512_min_ps(vacc0123456789ABCDEF, vmax);
657
658 _mm512_mask_storeu_ps(output, vmask, vacc0123456789ABCDEF);
659 output += c;
660 }
661
662 output = (float*) ((uintptr_t) output + output_increment);
663 } while (--output_width != 0);
664 }
665
666 void xnn_f32_dwconv_minmax_ukernel_up16x9__avx512f(
667 size_t channels,
668 size_t output_width,
669 const float** input,
670 const float* weights,
671 float* output,
672 size_t input_stride,
673 size_t output_increment,
674 size_t input_offset,
675 const float* zero,
676 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
677 {
678 assert(channels != 0);
679 assert(output_width != 0);
680
681 const __m512 vmax = _mm512_set1_ps(params->scalar.max);
682 const __m512 vmin = _mm512_set1_ps(params->scalar.min);
683 do {
684 const float* i0 = input[0];
685 assert(i0 != NULL);
686 if XNN_UNPREDICTABLE(i0 != zero) {
687 i0 = (const float*) ((uintptr_t) i0 + input_offset);
688 }
689 const float* i1 = input[1];
690 assert(i1 != NULL);
691 if XNN_UNPREDICTABLE(i1 != zero) {
692 i1 = (const float*) ((uintptr_t) i1 + input_offset);
693 }
694 const float* i2 = input[2];
695 assert(i2 != NULL);
696 if XNN_UNPREDICTABLE(i2 != zero) {
697 i2 = (const float*) ((uintptr_t) i2 + input_offset);
698 }
699 const float* i3 = input[3];
700 assert(i3 != NULL);
701 if XNN_UNPREDICTABLE(i3 != zero) {
702 i3 = (const float*) ((uintptr_t) i3 + input_offset);
703 }
704 const float* i4 = input[4];
705 assert(i4 != NULL);
706 if XNN_UNPREDICTABLE(i4 != zero) {
707 i4 = (const float*) ((uintptr_t) i4 + input_offset);
708 }
709 const float* i5 = input[5];
710 assert(i5 != NULL);
711 if XNN_UNPREDICTABLE(i5 != zero) {
712 i5 = (const float*) ((uintptr_t) i5 + input_offset);
713 }
714 const float* i6 = input[6];
715 assert(i6 != NULL);
716 if XNN_UNPREDICTABLE(i6 != zero) {
717 i6 = (const float*) ((uintptr_t) i6 + input_offset);
718 }
719 const float* i7 = input[7];
720 assert(i7 != NULL);
721 if XNN_UNPREDICTABLE(i7 != zero) {
722 i7 = (const float*) ((uintptr_t) i7 + input_offset);
723 }
724 const float* i8 = input[8];
725 assert(i8 != NULL);
726 if XNN_UNPREDICTABLE(i8 != zero) {
727 i8 = (const float*) ((uintptr_t) i8 + input_offset);
728 }
729 input = (const float**) ((uintptr_t) input + input_stride);
730
731 size_t c = channels;
732 const float* w = weights;
733 for (; c >= 16; c -= 16) {
734 __m512 vacc0123456789ABCDEFp0 = _mm512_load_ps(w);
735
736
737 const __m512 vi0x0123456789ABCDEF = _mm512_loadu_ps(i0);
738 i0 += 16;
739
740 const __m512 vk0x0123456789ABCDEF = _mm512_load_ps(w + 16);
741 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
742
743 const __m512 vi1x0123456789ABCDEF = _mm512_loadu_ps(i1);
744 i1 += 16;
745
746 const __m512 vk1x0123456789ABCDEF = _mm512_load_ps(w + 32);
747 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
748
749 const __m512 vi2x0123456789ABCDEF = _mm512_loadu_ps(i2);
750 i2 += 16;
751
752 const __m512 vk2x0123456789ABCDEF = _mm512_load_ps(w + 48);
753 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
754
755 const __m512 vi3x0123456789ABCDEF = _mm512_loadu_ps(i3);
756 i3 += 16;
757
758 const __m512 vk3x0123456789ABCDEF = _mm512_load_ps(w + 64);
759 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
760
761 const __m512 vi4x0123456789ABCDEF = _mm512_loadu_ps(i4);
762 i4 += 16;
763
764 const __m512 vk4x0123456789ABCDEF = _mm512_load_ps(w + 80);
765 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
766
767 const __m512 vi5x0123456789ABCDEF = _mm512_loadu_ps(i5);
768 i5 += 16;
769
770 const __m512 vk5x0123456789ABCDEF = _mm512_load_ps(w + 96);
771 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp0);
772
773 const __m512 vi6x0123456789ABCDEF = _mm512_loadu_ps(i6);
774 i6 += 16;
775
776 const __m512 vk6x0123456789ABCDEF = _mm512_load_ps(w + 112);
777 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
778
779 const __m512 vi7x0123456789ABCDEF = _mm512_loadu_ps(i7);
780 i7 += 16;
781
782 const __m512 vk7x0123456789ABCDEF = _mm512_load_ps(w + 128);
783 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp0);
784
785 const __m512 vi8x0123456789ABCDEF = _mm512_loadu_ps(i8);
786 i8 += 16;
787
788 const __m512 vk8x0123456789ABCDEF = _mm512_load_ps(w + 144);
789 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
790
791 w += 160;
792
793
794 __m512 vacc0123456789ABCDEF = _mm512_max_ps(vacc0123456789ABCDEFp0, vmin);
795 vacc0123456789ABCDEF = _mm512_min_ps(vacc0123456789ABCDEF, vmax);
796
797 _mm512_storeu_ps(output, vacc0123456789ABCDEF);
798 output += 16;
799 }
800 if XNN_UNLIKELY(c != 0) {
801 assert(c >= 1);
802 assert(c <= 16);
803       // Prepare mask for valid 32-bit elements (depends on c).
804 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
805
806 __m512 vacc0123456789ABCDEFp0 = _mm512_maskz_loadu_ps(vmask, w);
807
808 const __m512 vi0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i0);
809 const __m512 vk0x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 16);
810 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF, vacc0123456789ABCDEFp0);
811
812 const __m512 vi1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i1);
813 const __m512 vk1x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 32);
814 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF, vacc0123456789ABCDEFp0);
815
816 const __m512 vi2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i2);
817 const __m512 vk2x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 48);
818 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF, vacc0123456789ABCDEFp0);
819
820 const __m512 vi3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i3);
821 const __m512 vk3x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 64);
822 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF, vacc0123456789ABCDEFp0);
823
824 const __m512 vi4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i4);
825 const __m512 vk4x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 80);
826 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF, vacc0123456789ABCDEFp0);
827
828 const __m512 vi5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i5);
829 const __m512 vk5x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 96);
830 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF, vacc0123456789ABCDEFp0);
831
832 const __m512 vi6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i6);
833 const __m512 vk6x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 112);
834 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF, vacc0123456789ABCDEFp0);
835
836 const __m512 vi7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i7);
837 const __m512 vk7x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 128);
838 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF, vacc0123456789ABCDEFp0);
839
840 const __m512 vi8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, i8);
841 const __m512 vk8x0123456789ABCDEF = _mm512_maskz_loadu_ps(vmask, w + 144);
842 vacc0123456789ABCDEFp0 = _mm512_fmadd_ps(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF, vacc0123456789ABCDEFp0);
843
844
845 __m512 vacc0123456789ABCDEF = _mm512_max_ps(vacc0123456789ABCDEFp0, vmin);
846 vacc0123456789ABCDEF = _mm512_min_ps(vacc0123456789ABCDEF, vmax);
847
848 _mm512_mask_storeu_ps(output, vmask, vacc0123456789ABCDEF);
849 output += c;
850 }
851
852 output = (float*) ((uintptr_t) output + output_increment);
853 } while (--output_width != 0);
854 }
855
856 void xnn_f32_gemm_minmax_ukernel_1x16__avx512f_broadcast(
857 size_t mr,
858 size_t nc,
859 size_t kc,
860 const float*restrict a,
861 size_t a_stride,
862 const float*restrict w,
863 float*restrict c,
864 size_t cm_stride,
865 size_t cn_stride,
866 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
867 {
868 assert(mr != 0);
869 assert(mr <= 1);
870 assert(nc != 0);
871 assert(kc != 0);
872 assert(kc % sizeof(float) == 0);
873 assert(a != NULL);
874 assert(w != NULL);
875 assert(c != NULL);
876
877 const float* a0 = a;
878 float* c0 = c;
879
880 do {
881 __m512 vacc0x0123456789ABCDEF = _mm512_load_ps(w);
882 w += 16;
883
884 size_t k = kc;
885 do {
886 const __m512 vb0123456789ABCDEF = _mm512_load_ps(w);
887 w += 16;
888
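      // Broadcast one scalar of A across all 16 lanes and accumulate a full 16-wide row of C with a single FMA.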
889 const __m512 va0 = _mm512_set1_ps(*a0);
890 vacc0x0123456789ABCDEF = _mm512_fmadd_ps(va0, vb0123456789ABCDEF, vacc0x0123456789ABCDEF);
891
892 a0 += 1;
893
894 k -= sizeof(float);
895 } while (k != 0);
896
897 const __m512 vmin = _mm512_set1_ps(params->scalar.min);
898 vacc0x0123456789ABCDEF = _mm512_max_ps(vacc0x0123456789ABCDEF, vmin);
899
900 const __m512 vmax = _mm512_set1_ps(params->scalar.max);
901 vacc0x0123456789ABCDEF = _mm512_min_ps(vacc0x0123456789ABCDEF, vmax);
902
903 if XNN_LIKELY(nc >= 16) {
904 _mm512_storeu_ps(c0, vacc0x0123456789ABCDEF);
905 c0 = (float*) ((uintptr_t) c0 + cn_stride);
906
907 a0 = (const float*) ((uintptr_t) a0 - kc);
908
909 nc -= 16;
910 } else {
911 if (nc & 15) {
912 // Prepare mask for valid 32-bit elements (depends on nc).
913 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << nc) - UINT32_C(1)));
914
915 _mm512_mask_storeu_ps(c0, vmask, vacc0x0123456789ABCDEF);
916 }
917
918 nc = 0;
919 }
920 } while (nc != 0);
921 }
922
923 void xnn_f32_gemm_minmax_ukernel_7x16__avx512f_broadcast(
924 size_t mr,
925 size_t nc,
926 size_t kc,
927 const float*restrict a,
928 size_t a_stride,
929 const float*restrict w,
930 float*restrict c,
931 size_t cm_stride,
932 size_t cn_stride,
933 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
934 {
935 assert(mr != 0);
936 assert(mr <= 7);
937 assert(nc != 0);
938 assert(kc != 0);
939 assert(kc % sizeof(float) == 0);
940 assert(a != NULL);
941 assert(w != NULL);
942 assert(c != NULL);
943
944 const float* a0 = a;
945 float* c0 = c;
946 const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
947 float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
948 if XNN_UNPREDICTABLE(mr < 2) {
949 a1 = a0;
950 c1 = c0;
951 }
952 const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
953 float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
954 if XNN_UNPREDICTABLE(mr <= 2) {
955 a2 = a1;
956 c2 = c1;
957 }
958 const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
959 float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
960 if XNN_UNPREDICTABLE(mr < 4) {
961 a3 = a2;
962 c3 = c2;
963 }
964 const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
965 float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
966 if XNN_UNPREDICTABLE(mr <= 4) {
967 a4 = a3;
968 c4 = c3;
969 }
970 const float* a5 = (const float*) ((uintptr_t) a4 + a_stride);
971 float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
972 if XNN_UNPREDICTABLE(mr < 6) {
973 a5 = a4;
974 c5 = c4;
975 }
976 const float* a6 = (const float*) ((uintptr_t) a5 + a_stride);
977 float* c6 = (float*) ((uintptr_t) c5 + cm_stride);
978 if XNN_UNPREDICTABLE(mr <= 6) {
979 a6 = a5;
980 c6 = c5;
981 }
982
983 do {
984 __m512 vacc0x0123456789ABCDEF = _mm512_load_ps(w);
985 __m512 vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF;
986 __m512 vacc2x0123456789ABCDEF = vacc0x0123456789ABCDEF;
987 __m512 vacc3x0123456789ABCDEF = vacc0x0123456789ABCDEF;
988 __m512 vacc4x0123456789ABCDEF = vacc0x0123456789ABCDEF;
989 __m512 vacc5x0123456789ABCDEF = vacc0x0123456789ABCDEF;
990 __m512 vacc6x0123456789ABCDEF = vacc0x0123456789ABCDEF;
991 w += 16;
992
993 size_t k = kc;
994 do {
995 const __m512 vb0123456789ABCDEF = _mm512_load_ps(w);
996 w += 16;
997
998 const __m512 va0 = _mm512_set1_ps(*a0);
999 vacc0x0123456789ABCDEF = _mm512_fmadd_ps(va0, vb0123456789ABCDEF, vacc0x0123456789ABCDEF);
1000 const __m512 va1 = _mm512_set1_ps(*a1);
1001 vacc1x0123456789ABCDEF = _mm512_fmadd_ps(va1, vb0123456789ABCDEF, vacc1x0123456789ABCDEF);
1002 const __m512 va2 = _mm512_set1_ps(*a2);
1003 vacc2x0123456789ABCDEF = _mm512_fmadd_ps(va2, vb0123456789ABCDEF, vacc2x0123456789ABCDEF);
1004 const __m512 va3 = _mm512_set1_ps(*a3);
1005 vacc3x0123456789ABCDEF = _mm512_fmadd_ps(va3, vb0123456789ABCDEF, vacc3x0123456789ABCDEF);
1006 const __m512 va4 = _mm512_set1_ps(*a4);
1007 vacc4x0123456789ABCDEF = _mm512_fmadd_ps(va4, vb0123456789ABCDEF, vacc4x0123456789ABCDEF);
1008 const __m512 va5 = _mm512_set1_ps(*a5);
1009 vacc5x0123456789ABCDEF = _mm512_fmadd_ps(va5, vb0123456789ABCDEF, vacc5x0123456789ABCDEF);
1010 const __m512 va6 = _mm512_set1_ps(*a6);
1011 vacc6x0123456789ABCDEF = _mm512_fmadd_ps(va6, vb0123456789ABCDEF, vacc6x0123456789ABCDEF);
1012
1013 a0 += 1;
1014 a1 += 1;
1015 a2 += 1;
1016 a3 += 1;
1017 a4 += 1;
1018 a5 += 1;
1019 a6 += 1;
1020
1021 k -= sizeof(float);
1022 } while (k != 0);
1023
1024 const __m512 vmin = _mm512_set1_ps(params->scalar.min);
1025 vacc0x0123456789ABCDEF = _mm512_max_ps(vacc0x0123456789ABCDEF, vmin);
1026 vacc1x0123456789ABCDEF = _mm512_max_ps(vacc1x0123456789ABCDEF, vmin);
1027 vacc2x0123456789ABCDEF = _mm512_max_ps(vacc2x0123456789ABCDEF, vmin);
1028 vacc3x0123456789ABCDEF = _mm512_max_ps(vacc3x0123456789ABCDEF, vmin);
1029 vacc4x0123456789ABCDEF = _mm512_max_ps(vacc4x0123456789ABCDEF, vmin);
1030 vacc5x0123456789ABCDEF = _mm512_max_ps(vacc5x0123456789ABCDEF, vmin);
1031 vacc6x0123456789ABCDEF = _mm512_max_ps(vacc6x0123456789ABCDEF, vmin);
1032
1033 const __m512 vmax = _mm512_set1_ps(params->scalar.max);
1034 vacc0x0123456789ABCDEF = _mm512_min_ps(vacc0x0123456789ABCDEF, vmax);
1035 vacc1x0123456789ABCDEF = _mm512_min_ps(vacc1x0123456789ABCDEF, vmax);
1036 vacc2x0123456789ABCDEF = _mm512_min_ps(vacc2x0123456789ABCDEF, vmax);
1037 vacc3x0123456789ABCDEF = _mm512_min_ps(vacc3x0123456789ABCDEF, vmax);
1038 vacc4x0123456789ABCDEF = _mm512_min_ps(vacc4x0123456789ABCDEF, vmax);
1039 vacc5x0123456789ABCDEF = _mm512_min_ps(vacc5x0123456789ABCDEF, vmax);
1040 vacc6x0123456789ABCDEF = _mm512_min_ps(vacc6x0123456789ABCDEF, vmax);
1041
1042 if XNN_LIKELY(nc >= 16) {
1043 _mm512_storeu_ps(c6, vacc6x0123456789ABCDEF);
1044 c6 = (float*) ((uintptr_t) c6 + cn_stride);
1045 _mm512_storeu_ps(c5, vacc5x0123456789ABCDEF);
1046 c5 = (float*) ((uintptr_t) c5 + cn_stride);
1047 _mm512_storeu_ps(c4, vacc4x0123456789ABCDEF);
1048 c4 = (float*) ((uintptr_t) c4 + cn_stride);
1049 _mm512_storeu_ps(c3, vacc3x0123456789ABCDEF);
1050 c3 = (float*) ((uintptr_t) c3 + cn_stride);
1051 _mm512_storeu_ps(c2, vacc2x0123456789ABCDEF);
1052 c2 = (float*) ((uintptr_t) c2 + cn_stride);
1053 _mm512_storeu_ps(c1, vacc1x0123456789ABCDEF);
1054 c1 = (float*) ((uintptr_t) c1 + cn_stride);
1055 _mm512_storeu_ps(c0, vacc0x0123456789ABCDEF);
1056 c0 = (float*) ((uintptr_t) c0 + cn_stride);
1057
1058 a6 = (const float*) ((uintptr_t) a6 - kc);
1059 a5 = (const float*) ((uintptr_t) a5 - kc);
1060 a4 = (const float*) ((uintptr_t) a4 - kc);
1061 a3 = (const float*) ((uintptr_t) a3 - kc);
1062 a2 = (const float*) ((uintptr_t) a2 - kc);
1063 a1 = (const float*) ((uintptr_t) a1 - kc);
1064 a0 = (const float*) ((uintptr_t) a0 - kc);
1065
1066 nc -= 16;
1067 } else {
1068 if (nc & 15) {
1069 // Prepare mask for valid 32-bit elements (depends on nc).
1070 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << nc) - UINT32_C(1)));
1071
1072 _mm512_mask_storeu_ps(c6, vmask, vacc6x0123456789ABCDEF);
1073 _mm512_mask_storeu_ps(c5, vmask, vacc5x0123456789ABCDEF);
1074 _mm512_mask_storeu_ps(c4, vmask, vacc4x0123456789ABCDEF);
1075 _mm512_mask_storeu_ps(c3, vmask, vacc3x0123456789ABCDEF);
1076 _mm512_mask_storeu_ps(c2, vmask, vacc2x0123456789ABCDEF);
1077 _mm512_mask_storeu_ps(c1, vmask, vacc1x0123456789ABCDEF);
1078 _mm512_mask_storeu_ps(c0, vmask, vacc0x0123456789ABCDEF);
1079 }
1080
1081 nc = 0;
1082 }
1083 } while (nc != 0);
1084 }
1085
1086 void xnn_f32_igemm_minmax_ukernel_1x16__avx512f_broadcast(
1087 size_t mr,
1088 size_t nc,
1089 size_t kc,
1090 size_t ks,
1091 const float**restrict a,
1092 const float*restrict w,
1093 float*restrict c,
1094 size_t cm_stride,
1095 size_t cn_stride,
1096 size_t a_offset,
1097 const float* zero,
1098 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
1099 {
1100 assert(mr != 0);
1101 assert(mr <= 1);
1102 assert(nc != 0);
1103 assert(kc != 0);
1104 assert(kc % sizeof(float) == 0);
1105 assert(ks != 0);
1106 assert(ks % (1 * sizeof(void*)) == 0);
1107 assert(a_offset % sizeof(float) == 0);
1108 assert(a != NULL);
1109 assert(w != NULL);
1110 assert(c != NULL);
1111
1112 float* c0 = c;
1113
1114 do {
1115 __m512 vacc0x0123456789ABCDEF = _mm512_load_ps(w);
1116 w += 16;
1117
1118 size_t p = ks;
1119 do {
1120 const float* restrict a0 = a[0];
1121 assert(a0 != NULL);
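      // Indirection entries equal to `zero` point at a shared zero-filled buffer (padding) and are intentionally left un-offset.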
1122 if XNN_UNPREDICTABLE(a0 != zero) {
1123 a0 = (const float*) ((uintptr_t) a0 + a_offset);
1124 }
1125 a += 1;
1126
1127 size_t k = kc;
1128 do {
1129 const __m512 vb0123456789ABCDEF = _mm512_load_ps(w);
1130 w += 16;
1131
1132 const __m512 va0 = _mm512_set1_ps(*a0);
1133 vacc0x0123456789ABCDEF = _mm512_fmadd_ps(va0, vb0123456789ABCDEF, vacc0x0123456789ABCDEF);
1134
1135 a0 += 1;
1136
1137 k -= sizeof(float);
1138 } while (k != 0);
1139 p -= 1 * sizeof(void*);
1140 } while (p != 0);
1141
1142 const __m512 vmin = _mm512_set1_ps(params->scalar.min);
1143 vacc0x0123456789ABCDEF = _mm512_max_ps(vacc0x0123456789ABCDEF, vmin);
1144
1145 const __m512 vmax = _mm512_set1_ps(params->scalar.max);
1146 vacc0x0123456789ABCDEF = _mm512_min_ps(vacc0x0123456789ABCDEF, vmax);
1147
1148 if XNN_LIKELY(nc >= 16) {
1149 _mm512_storeu_ps(c0, vacc0x0123456789ABCDEF);
1150 c0 = (float*) ((uintptr_t) c0 + cn_stride);
1151
1152 a = (const float**restrict) ((uintptr_t) a - ks);
1153 nc -= 16;
1154 } else {
1155 if (nc & 15) {
1156 // Prepare mask for valid 32-bit elements (depends on nc).
1157 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << nc) - UINT32_C(1)));
1158
1159 _mm512_mask_storeu_ps(c0, vmask, vacc0x0123456789ABCDEF);
1160 }
1161
1162 nc = 0;
1163 }
1164 } while (nc != 0);
1165 }
1166
1167 void xnn_f32_igemm_minmax_ukernel_7x16__avx512f_broadcast(
1168 size_t mr,
1169 size_t nc,
1170 size_t kc,
1171 size_t ks,
1172 const float**restrict a,
1173 const float*restrict w,
1174 float*restrict c,
1175 size_t cm_stride,
1176 size_t cn_stride,
1177 size_t a_offset,
1178 const float* zero,
1179 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
1180 {
1181 assert(mr != 0);
1182 assert(mr <= 7);
1183 assert(nc != 0);
1184 assert(kc != 0);
1185 assert(kc % sizeof(float) == 0);
1186 assert(ks != 0);
1187 assert(ks % (7 * sizeof(void*)) == 0);
1188 assert(a_offset % sizeof(float) == 0);
1189 assert(a != NULL);
1190 assert(w != NULL);
1191 assert(c != NULL);
1192
1193 float* c0 = c;
1194 float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
1195 if XNN_UNPREDICTABLE(mr < 2) {
1196 c1 = c0;
1197 }
1198 float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
1199 if XNN_UNPREDICTABLE(mr <= 2) {
1200 c2 = c1;
1201 }
1202 float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
1203 if XNN_UNPREDICTABLE(mr < 4) {
1204 c3 = c2;
1205 }
1206 float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
1207 if XNN_UNPREDICTABLE(mr <= 4) {
1208 c4 = c3;
1209 }
1210 float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
1211 if XNN_UNPREDICTABLE(mr < 6) {
1212 c5 = c4;
1213 }
1214 float* c6 = (float*) ((uintptr_t) c5 + cm_stride);
1215 if XNN_UNPREDICTABLE(mr <= 6) {
1216 c6 = c5;
1217 }
1218
1219 do {
1220 __m512 vacc0x0123456789ABCDEF = _mm512_load_ps(w);
1221 __m512 vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF;
1222 __m512 vacc2x0123456789ABCDEF = vacc0x0123456789ABCDEF;
1223 __m512 vacc3x0123456789ABCDEF = vacc0x0123456789ABCDEF;
1224 __m512 vacc4x0123456789ABCDEF = vacc0x0123456789ABCDEF;
1225 __m512 vacc5x0123456789ABCDEF = vacc0x0123456789ABCDEF;
1226 __m512 vacc6x0123456789ABCDEF = vacc0x0123456789ABCDEF;
1227 w += 16;
1228
1229 size_t p = ks;
1230 do {
1231 const float* restrict a0 = a[0];
1232 assert(a0 != NULL);
1233 if XNN_UNPREDICTABLE(a0 != zero) {
1234 a0 = (const float*) ((uintptr_t) a0 + a_offset);
1235 }
1236 const float* restrict a1 = a[1];
1237 assert(a1 != NULL);
1238 if XNN_UNPREDICTABLE(a1 != zero) {
1239 a1 = (const float*) ((uintptr_t) a1 + a_offset);
1240 }
1241 const float* restrict a2 = a[2];
1242 assert(a2 != NULL);
1243 if XNN_UNPREDICTABLE(a2 != zero) {
1244 a2 = (const float*) ((uintptr_t) a2 + a_offset);
1245 }
1246 const float* restrict a3 = a[3];
1247 assert(a3 != NULL);
1248 if XNN_UNPREDICTABLE(a3 != zero) {
1249 a3 = (const float*) ((uintptr_t) a3 + a_offset);
1250 }
1251 const float* restrict a4 = a[4];
1252 assert(a4 != NULL);
1253 if XNN_UNPREDICTABLE(a4 != zero) {
1254 a4 = (const float*) ((uintptr_t) a4 + a_offset);
1255 }
1256 const float* restrict a5 = a[5];
1257 assert(a5 != NULL);
1258 if XNN_UNPREDICTABLE(a5 != zero) {
1259 a5 = (const float*) ((uintptr_t) a5 + a_offset);
1260 }
1261 const float* restrict a6 = a[6];
1262 assert(a6 != NULL);
1263 if XNN_UNPREDICTABLE(a6 != zero) {
1264 a6 = (const float*) ((uintptr_t) a6 + a_offset);
1265 }
1266 a += 7;
1267
1268 size_t k = kc;
1269 do {
1270 const __m512 vb0123456789ABCDEF = _mm512_load_ps(w);
1271 w += 16;
1272
1273 const __m512 va0 = _mm512_set1_ps(*a0);
1274 vacc0x0123456789ABCDEF = _mm512_fmadd_ps(va0, vb0123456789ABCDEF, vacc0x0123456789ABCDEF);
1275 const __m512 va1 = _mm512_set1_ps(*a1);
1276 vacc1x0123456789ABCDEF = _mm512_fmadd_ps(va1, vb0123456789ABCDEF, vacc1x0123456789ABCDEF);
1277 const __m512 va2 = _mm512_set1_ps(*a2);
1278 vacc2x0123456789ABCDEF = _mm512_fmadd_ps(va2, vb0123456789ABCDEF, vacc2x0123456789ABCDEF);
1279 const __m512 va3 = _mm512_set1_ps(*a3);
1280 vacc3x0123456789ABCDEF = _mm512_fmadd_ps(va3, vb0123456789ABCDEF, vacc3x0123456789ABCDEF);
1281 const __m512 va4 = _mm512_set1_ps(*a4);
1282 vacc4x0123456789ABCDEF = _mm512_fmadd_ps(va4, vb0123456789ABCDEF, vacc4x0123456789ABCDEF);
1283 const __m512 va5 = _mm512_set1_ps(*a5);
1284 vacc5x0123456789ABCDEF = _mm512_fmadd_ps(va5, vb0123456789ABCDEF, vacc5x0123456789ABCDEF);
1285 const __m512 va6 = _mm512_set1_ps(*a6);
1286 vacc6x0123456789ABCDEF = _mm512_fmadd_ps(va6, vb0123456789ABCDEF, vacc6x0123456789ABCDEF);
1287
1288 a0 += 1;
1289 a1 += 1;
1290 a2 += 1;
1291 a3 += 1;
1292 a4 += 1;
1293 a5 += 1;
1294 a6 += 1;
1295
1296 k -= sizeof(float);
1297 } while (k != 0);
1298 p -= 7 * sizeof(void*);
1299 } while (p != 0);
1300
1301 const __m512 vmin = _mm512_set1_ps(params->scalar.min);
1302 vacc0x0123456789ABCDEF = _mm512_max_ps(vacc0x0123456789ABCDEF, vmin);
1303 vacc1x0123456789ABCDEF = _mm512_max_ps(vacc1x0123456789ABCDEF, vmin);
1304 vacc2x0123456789ABCDEF = _mm512_max_ps(vacc2x0123456789ABCDEF, vmin);
1305 vacc3x0123456789ABCDEF = _mm512_max_ps(vacc3x0123456789ABCDEF, vmin);
1306 vacc4x0123456789ABCDEF = _mm512_max_ps(vacc4x0123456789ABCDEF, vmin);
1307 vacc5x0123456789ABCDEF = _mm512_max_ps(vacc5x0123456789ABCDEF, vmin);
1308 vacc6x0123456789ABCDEF = _mm512_max_ps(vacc6x0123456789ABCDEF, vmin);
1309
1310 const __m512 vmax = _mm512_set1_ps(params->scalar.max);
1311 vacc0x0123456789ABCDEF = _mm512_min_ps(vacc0x0123456789ABCDEF, vmax);
1312 vacc1x0123456789ABCDEF = _mm512_min_ps(vacc1x0123456789ABCDEF, vmax);
1313 vacc2x0123456789ABCDEF = _mm512_min_ps(vacc2x0123456789ABCDEF, vmax);
1314 vacc3x0123456789ABCDEF = _mm512_min_ps(vacc3x0123456789ABCDEF, vmax);
1315 vacc4x0123456789ABCDEF = _mm512_min_ps(vacc4x0123456789ABCDEF, vmax);
1316 vacc5x0123456789ABCDEF = _mm512_min_ps(vacc5x0123456789ABCDEF, vmax);
1317 vacc6x0123456789ABCDEF = _mm512_min_ps(vacc6x0123456789ABCDEF, vmax);
1318
1319 if XNN_LIKELY(nc >= 16) {
1320 _mm512_storeu_ps(c6, vacc6x0123456789ABCDEF);
1321 c6 = (float*) ((uintptr_t) c6 + cn_stride);
1322 _mm512_storeu_ps(c5, vacc5x0123456789ABCDEF);
1323 c5 = (float*) ((uintptr_t) c5 + cn_stride);
1324 _mm512_storeu_ps(c4, vacc4x0123456789ABCDEF);
1325 c4 = (float*) ((uintptr_t) c4 + cn_stride);
1326 _mm512_storeu_ps(c3, vacc3x0123456789ABCDEF);
1327 c3 = (float*) ((uintptr_t) c3 + cn_stride);
1328 _mm512_storeu_ps(c2, vacc2x0123456789ABCDEF);
1329 c2 = (float*) ((uintptr_t) c2 + cn_stride);
1330 _mm512_storeu_ps(c1, vacc1x0123456789ABCDEF);
1331 c1 = (float*) ((uintptr_t) c1 + cn_stride);
1332 _mm512_storeu_ps(c0, vacc0x0123456789ABCDEF);
1333 c0 = (float*) ((uintptr_t) c0 + cn_stride);
1334
1335 a = (const float**restrict) ((uintptr_t) a - ks);
1336 nc -= 16;
1337 } else {
1338 if (nc & 15) {
1339 // Prepare mask for valid 32-bit elements (depends on nc).
1340 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << nc) - UINT32_C(1)));
1341
1342 _mm512_mask_storeu_ps(c6, vmask, vacc6x0123456789ABCDEF);
1343 _mm512_mask_storeu_ps(c5, vmask, vacc5x0123456789ABCDEF);
1344 _mm512_mask_storeu_ps(c4, vmask, vacc4x0123456789ABCDEF);
1345 _mm512_mask_storeu_ps(c3, vmask, vacc3x0123456789ABCDEF);
1346 _mm512_mask_storeu_ps(c2, vmask, vacc2x0123456789ABCDEF);
1347 _mm512_mask_storeu_ps(c1, vmask, vacc1x0123456789ABCDEF);
1348 _mm512_mask_storeu_ps(c0, vmask, vacc0x0123456789ABCDEF);
1349 }
1350
1351 nc = 0;
1352 }
1353 } while (nc != 0);
1354 }
1355
1356 void xnn_f32_prelu_ukernel__avx512f_2x16(
1357 size_t rows,
1358 size_t channels,
1359 const float*restrict input,
1360 size_t input_stride,
1361 const float*restrict weights,
1362 float*restrict output,
1363 size_t output_stride)
1364 {
1365 assert(rows != 0);
1366 assert(channels != 0);
1367 assert(channels % sizeof(float) == 0);
1368
1369 const float* i0 = input;
1370 float* o0 = output;
1371 const float* i1 = (const float*) ((uintptr_t) i0 + input_stride);
1372 float* o1 = (float*) ((uintptr_t) o0 + output_stride);
1373
1374 const size_t input_increment = input_stride * 2 - channels;
1375 const size_t output_increment = output_stride * 2 - channels;
1376
1377 const __m512 vzero = _mm512_setzero_ps();
1378 do {
1379 if XNN_UNPREDICTABLE(rows < 2) {
1380 i1 = i0;
1381 o1 = o0;
1382 }
1383
1384 const float* w = weights;
1385 size_t c = channels;
1386 for (; c >= 16 * sizeof(float); c -= 16 * sizeof(float)) {
1387 const __m512 vw0123456789ABCDEF = _mm512_load_ps(w);
1388 w += 16;
1389
1390 const __m512 vi0x0123456789ABCDEF = _mm512_loadu_ps(i0);
1391 i0 += 16;
1392 const __m512 vi1x0123456789ABCDEF = _mm512_loadu_ps(i1);
1393 i1 += 16;
1394
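      // Negative lanes (mask set) are replaced by input * per-channel slope; non-negative lanes pass through unchanged, i.e. PReLU.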
1395 const __mmask16 vsign0x0123456789ABCDEF = _mm512_cmp_ps_mask(vi0x0123456789ABCDEF, vzero, _CMP_LT_OQ);
1396 const __m512 vacc0x0123456789ABCDEF = _mm512_mask_mul_ps(vi0x0123456789ABCDEF, vsign0x0123456789ABCDEF, vi0x0123456789ABCDEF, vw0123456789ABCDEF);
1397 const __mmask16 vsign1x0123456789ABCDEF = _mm512_cmp_ps_mask(vi1x0123456789ABCDEF, vzero, _CMP_LT_OQ);
1398 const __m512 vacc1x0123456789ABCDEF = _mm512_mask_mul_ps(vi1x0123456789ABCDEF, vsign1x0123456789ABCDEF, vi1x0123456789ABCDEF, vw0123456789ABCDEF);
1399
1400 _mm512_storeu_ps(o0, vacc0x0123456789ABCDEF);
1401 o0 += 16;
1402 _mm512_storeu_ps(o1, vacc1x0123456789ABCDEF);
1403 o1 += 16;
1404 }
1405 if XNN_UNLIKELY(c != 0) {
1406 assert(c >= 1 * sizeof(float));
1407 assert(c <= 15 * sizeof(float));
1408 // Prepare mask for valid 32-bit elements (depends on c).
1409 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << (c >> 2 /* log2(sizeof(float)) */)) - UINT32_C(1)));
1410
1411 const __m512 vw = _mm512_maskz_loadu_ps(vmask, w);
1412
1413 const __m512 vi0 = _mm512_maskz_loadu_ps(vmask, i0);
1414 i0 = (const float*) ((uintptr_t) i0 + c);
1415 const __m512 vi1 = _mm512_maskz_loadu_ps(vmask, i1);
1416 i1 = (const float*) ((uintptr_t) i1 + c);
1417
1418 const __mmask16 vsign0 = _mm512_cmp_ps_mask(vi0, vzero, _CMP_LT_OQ);
1419 const __m512 vacc0 = _mm512_mask_mul_ps(vi0, vsign0, vi0, vw);
1420 const __mmask16 vsign1 = _mm512_cmp_ps_mask(vi1, vzero, _CMP_LT_OQ);
1421 const __m512 vacc1 = _mm512_mask_mul_ps(vi1, vsign1, vi1, vw);
1422
1423 _mm512_mask_storeu_ps(o0, vmask, vacc0);
1424 o0 = (float*) ((uintptr_t) o0 + c);
1425 _mm512_mask_storeu_ps(o1, vmask, vacc1);
1426 o1 = (float*) ((uintptr_t) o1 + c);
1427 }
1428 i0 = (const float*) ((uintptr_t) i0 + input_increment);
1429 o0 = (float*) ((uintptr_t) o0 + output_increment);
1430 i1 = (const float*) ((uintptr_t) i1 + input_increment);
1431 o1 = (float*) ((uintptr_t) o1 + output_increment);
1432 rows = doz(rows, 2);
1433 } while (rows != 0);
1434 }
1435
1436 void xnn_f32_vadd_minmax_ukernel__avx512f_x32(
1437 size_t n,
1438 const float* a,
1439 const float* b,
1440 float* y,
1441 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
1442 {
1443 assert(n != 0);
1444 assert(n % sizeof(float) == 0);
1445 assert(a != NULL);
1446 assert(b != NULL);
1447 assert(y != NULL);
1448
1449 const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
1450 const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
1451
1452 for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
1453 const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
1454 const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
1455 a += 32;
1456
1457 const __m512 vb0123456789ABCDEF = _mm512_loadu_ps(b);
1458 const __m512 vbGHIJKLMNOPQRSTUV = _mm512_loadu_ps(b + 16);
1459 b += 32;
1460
1461 __m512 vy0123456789ABCDEF = _mm512_add_ps(va0123456789ABCDEF, vb0123456789ABCDEF);
1462 __m512 vyGHIJKLMNOPQRSTUV = _mm512_add_ps(vaGHIJKLMNOPQRSTUV, vbGHIJKLMNOPQRSTUV);
1463
1464
1465 vy0123456789ABCDEF = _mm512_max_ps(vy0123456789ABCDEF, vy_min);
1466 vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vyGHIJKLMNOPQRSTUV, vy_min);
1467
1468 vy0123456789ABCDEF = _mm512_min_ps(vy0123456789ABCDEF, vy_max);
1469 vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vyGHIJKLMNOPQRSTUV, vy_max);
1470
1471 _mm512_storeu_ps(y, vy0123456789ABCDEF);
1472 _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
1473 y += 32;
1474 }
1475 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
1476 const __m512 va = _mm512_loadu_ps(a);
1477 a += 16;
1478
1479 const __m512 vb = _mm512_loadu_ps(b);
1480 b += 16;
1481
1482 __m512 vy = _mm512_add_ps(va, vb);
1483 vy = _mm512_max_ps(vy, vy_min);
1484 vy = _mm512_min_ps(vy, vy_max);
1485 _mm512_storeu_ps(y, vy);
1486 y += 16;
1487 }
1488 if XNN_UNLIKELY(n != 0) {
1489 assert(n >= 1 * sizeof(float));
1490 assert(n <= 15 * sizeof(float));
1491 // Prepare mask for valid 32-bit elements (depends on n).
1492 n >>= 2 /* log2(sizeof(float)) */;
1493 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
1494
1495 const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
1496 const __m512 vb = _mm512_maskz_loadu_ps(vmask, b);
1497
1498 __m512 vy = _mm512_add_ps(va, vb);
1499 vy = _mm512_max_ps(vy, vy_min);
1500 vy = _mm512_min_ps(vy, vy_max);
1501 _mm512_mask_storeu_ps(y, vmask, vy);
1502 }
1503 }
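/*
 * Illustrative call sketch (not part of the generated code above): a clamped
 * elementwise add of two 100-float buffers. This assumes, as the accesses in the
 * kernel suggest, that union xnn_f32_minmax_params can be initialized through a
 * .scalar member with .min and .max fields.
 *
 *   const union xnn_f32_minmax_params p = { .scalar = { .min = 0.0f, .max = 6.0f } };
 *   xnn_f32_vadd_minmax_ukernel__avx512f_x32(100 * sizeof(float), a, b, y, &p);
 */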
1504
1505 void xnn_f32_vaddc_minmax_ukernel__avx512f_x32(
1506 size_t n,
1507 const float* a,
1508 const float* b,
1509 float* y,
1510 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
1511 {
1512 assert(n != 0);
1513 assert(n % sizeof(float) == 0);
1514 assert(a != NULL);
1515 assert(b != NULL);
1516 assert(y != NULL);
1517
1518 const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
1519 const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
1520
1521 const __m512 vb = _mm512_set1_ps(*b);
1522 for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
1523 const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
1524 const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
1525 a += 32;
1526
1527 __m512 vy0123456789ABCDEF = _mm512_add_ps(va0123456789ABCDEF, vb);
1528 __m512 vyGHIJKLMNOPQRSTUV = _mm512_add_ps(vaGHIJKLMNOPQRSTUV, vb);
1529
1530
1531 vy0123456789ABCDEF = _mm512_max_ps(vy0123456789ABCDEF, vy_min);
1532 vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vyGHIJKLMNOPQRSTUV, vy_min);
1533
1534 vy0123456789ABCDEF = _mm512_min_ps(vy0123456789ABCDEF, vy_max);
1535 vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vyGHIJKLMNOPQRSTUV, vy_max);
1536
1537 _mm512_storeu_ps(y, vy0123456789ABCDEF);
1538 _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
1539 y += 32;
1540 }
1541 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
1542 const __m512 va = _mm512_loadu_ps(a);
1543 a += 16;
1544
1545 __m512 vy = _mm512_add_ps(va, vb);
1546 vy = _mm512_max_ps(vy, vy_min);
1547 vy = _mm512_min_ps(vy, vy_max);
1548 _mm512_storeu_ps(y, vy);
1549 y += 16;
1550 }
1551 if XNN_UNLIKELY(n != 0) {
1552 assert(n >= 1 * sizeof(float));
1553 assert(n <= 15 * sizeof(float));
1554 // Prepare mask for valid 32-bit elements (depends on n).
1555 n >>= 2 /* log2(sizeof(float)) */;
1556 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
1557
1558 const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
1559
1560 __m512 vy = _mm512_add_ps(va, vb);
1561 vy = _mm512_max_ps(vy, vy_min);
1562 vy = _mm512_min_ps(vy, vy_max);
1563 _mm512_mask_storeu_ps(y, vmask, vy);
1564 }
1565 }
1566
1567 void xnn_f32_vdiv_minmax_ukernel__avx512f_x32(
1568 size_t n,
1569 const float* a,
1570 const float* b,
1571 float* y,
1572 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
1573 {
1574 assert(n != 0);
1575 assert(n % sizeof(float) == 0);
1576 assert(a != NULL);
1577 assert(b != NULL);
1578 assert(y != NULL);
1579
1580 const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
1581 const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
1582
1583 for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
1584 const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
1585 const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
1586 a += 32;
1587
1588 const __m512 vb0123456789ABCDEF = _mm512_loadu_ps(b);
1589 const __m512 vbGHIJKLMNOPQRSTUV = _mm512_loadu_ps(b + 16);
1590 b += 32;
1591
1592 __m512 vy0123456789ABCDEF = _mm512_div_ps(va0123456789ABCDEF, vb0123456789ABCDEF);
1593 __m512 vyGHIJKLMNOPQRSTUV = _mm512_div_ps(vaGHIJKLMNOPQRSTUV, vbGHIJKLMNOPQRSTUV);
1594
1595
1596 vy0123456789ABCDEF = _mm512_max_ps(vy0123456789ABCDEF, vy_min);
1597 vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vyGHIJKLMNOPQRSTUV, vy_min);
1598
1599 vy0123456789ABCDEF = _mm512_min_ps(vy0123456789ABCDEF, vy_max);
1600 vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vyGHIJKLMNOPQRSTUV, vy_max);
1601
1602 _mm512_storeu_ps(y, vy0123456789ABCDEF);
1603 _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
1604 y += 32;
1605 }
1606 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
1607 const __m512 va = _mm512_loadu_ps(a);
1608 a += 16;
1609
1610 const __m512 vb = _mm512_loadu_ps(b);
1611 b += 16;
1612
1613 __m512 vy = _mm512_div_ps(va, vb);
1614 vy = _mm512_max_ps(vy, vy_min);
1615 vy = _mm512_min_ps(vy, vy_max);
1616 _mm512_storeu_ps(y, vy);
1617 y += 16;
1618 }
1619 if XNN_UNLIKELY(n != 0) {
1620 assert(n >= 1 * sizeof(float));
1621 assert(n <= 15 * sizeof(float));
1622 // Prepare mask for valid 32-bit elements (depends on n).
1623 n >>= 2 /* log2(sizeof(float)) */;
1624 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
1625
1626 const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
1627 const __m512 vb = _mm512_maskz_loadu_ps(vmask, b);
1628
1629 __m512 vy = _mm512_div_ps(va, vb);
1630 vy = _mm512_max_ps(vy, vy_min);
1631 vy = _mm512_min_ps(vy, vy_max);
1632 _mm512_mask_storeu_ps(y, vmask, vy);
1633 }
1634 }
1635
1636 void xnn_f32_vdivc_minmax_ukernel__avx512f_x32(
1637 size_t n,
1638 const float* a,
1639 const float* b,
1640 float* y,
1641 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
1642 {
1643 assert(n != 0);
1644 assert(n % sizeof(float) == 0);
1645 assert(a != NULL);
1646 assert(b != NULL);
1647 assert(y != NULL);
1648
1649 const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
1650 const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
1651
1652 const __m512 vb = _mm512_set1_ps(*b);
1653 for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
1654 const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
1655 const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
1656 a += 32;
1657
1658 __m512 vy0123456789ABCDEF = _mm512_div_ps(va0123456789ABCDEF, vb);
1659 __m512 vyGHIJKLMNOPQRSTUV = _mm512_div_ps(vaGHIJKLMNOPQRSTUV, vb);
1660
1661
1662 vy0123456789ABCDEF = _mm512_max_ps(vy0123456789ABCDEF, vy_min);
1663 vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vyGHIJKLMNOPQRSTUV, vy_min);
1664
1665 vy0123456789ABCDEF = _mm512_min_ps(vy0123456789ABCDEF, vy_max);
1666 vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vyGHIJKLMNOPQRSTUV, vy_max);
1667
1668 _mm512_storeu_ps(y, vy0123456789ABCDEF);
1669 _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
1670 y += 32;
1671 }
1672 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
1673 const __m512 va = _mm512_loadu_ps(a);
1674 a += 16;
1675
1676 __m512 vy = _mm512_div_ps(va, vb);
1677 vy = _mm512_max_ps(vy, vy_min);
1678 vy = _mm512_min_ps(vy, vy_max);
1679 _mm512_storeu_ps(y, vy);
1680 y += 16;
1681 }
1682 if XNN_UNLIKELY(n != 0) {
1683 assert(n >= 1 * sizeof(float));
1684 assert(n <= 15 * sizeof(float));
1685 // Prepare mask for valid 32-bit elements (depends on n).
1686 n >>= 2 /* log2(sizeof(float)) */;
1687 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
1688
1689 const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
1690
1691 __m512 vy = _mm512_div_ps(va, vb);
1692 vy = _mm512_max_ps(vy, vy_min);
1693 vy = _mm512_min_ps(vy, vy_max);
1694 _mm512_mask_storeu_ps(y, vmask, vy);
1695 }
1696 }
1697
1698 void xnn_f32_vmax_ukernel__avx512f_x32(
1699 size_t n,
1700 const float* a,
1701 const float* b,
1702 float* y,
1703 const union xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
1704 {
1705 assert(n != 0);
1706 assert(n % sizeof(float) == 0);
1707 assert(a != NULL);
1708 assert(b != NULL);
1709 assert(y != NULL);
1710
1711
1712 for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
1713 const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
1714 const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
1715 a += 32;
1716
1717 const __m512 vb0123456789ABCDEF = _mm512_loadu_ps(b);
1718 const __m512 vbGHIJKLMNOPQRSTUV = _mm512_loadu_ps(b + 16);
1719 b += 32;
1720
1721 __m512 vy0123456789ABCDEF = _mm512_max_ps(va0123456789ABCDEF, vb0123456789ABCDEF);
1722 __m512 vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vaGHIJKLMNOPQRSTUV, vbGHIJKLMNOPQRSTUV);
1723
1724
1725
1726 _mm512_storeu_ps(y, vy0123456789ABCDEF);
1727 _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
1728 y += 32;
1729 }
1730 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
1731 const __m512 va = _mm512_loadu_ps(a);
1732 a += 16;
1733
1734 const __m512 vb = _mm512_loadu_ps(b);
1735 b += 16;
1736
1737 __m512 vy = _mm512_max_ps(va, vb);
1738 _mm512_storeu_ps(y, vy);
1739 y += 16;
1740 }
1741 if XNN_UNLIKELY(n != 0) {
1742 assert(n >= 1 * sizeof(float));
1743 assert(n <= 15 * sizeof(float));
1744 // Prepare mask for valid 32-bit elements (depends on n).
1745 n >>= 2 /* log2(sizeof(float)) */;
1746 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
1747
1748 const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
1749 const __m512 vb = _mm512_maskz_loadu_ps(vmask, b);
1750
1751 __m512 vy = _mm512_max_ps(va, vb);
1752 _mm512_mask_storeu_ps(y, vmask, vy);
1753 }
1754 }
1755
1756 void xnn_f32_vmaxc_ukernel__avx512f_x32(
1757 size_t n,
1758 const float* a,
1759 const float* b,
1760 float* y,
1761 const union xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
1762 {
1763 assert(n != 0);
1764 assert(n % sizeof(float) == 0);
1765 assert(a != NULL);
1766 assert(b != NULL);
1767 assert(y != NULL);
1768
1769
1770 const __m512 vb = _mm512_set1_ps(*b);
1771 for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
1772 const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
1773 const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
1774 a += 32;
1775
1776 __m512 vy0123456789ABCDEF = _mm512_max_ps(va0123456789ABCDEF, vb);
1777 __m512 vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vaGHIJKLMNOPQRSTUV, vb);
1778
1779
1780
1781 _mm512_storeu_ps(y, vy0123456789ABCDEF);
1782 _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
1783 y += 32;
1784 }
1785 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
1786 const __m512 va = _mm512_loadu_ps(a);
1787 a += 16;
1788
1789 __m512 vy = _mm512_max_ps(va, vb);
1790 _mm512_storeu_ps(y, vy);
1791 y += 16;
1792 }
1793 if XNN_UNLIKELY(n != 0) {
1794 assert(n >= 1 * sizeof(float));
1795 assert(n <= 15 * sizeof(float));
1796 // Prepare mask for valid 32-bit elements (depends on n).
1797 n >>= 2 /* log2(sizeof(float)) */;
1798 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
1799
1800 const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
1801
1802 __m512 vy = _mm512_max_ps(va, vb);
1803 _mm512_mask_storeu_ps(y, vmask, vy);
1804 }
1805 }
1806
1807 void xnn_f32_vmin_ukernel__avx512f_x32(
1808 size_t n,
1809 const float* a,
1810 const float* b,
1811 float* y,
1812 const union xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
1813 {
1814 assert(n != 0);
1815 assert(n % sizeof(float) == 0);
1816 assert(a != NULL);
1817 assert(b != NULL);
1818 assert(y != NULL);
1819
1820
1821 for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
1822 const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
1823 const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
1824 a += 32;
1825
1826 const __m512 vb0123456789ABCDEF = _mm512_loadu_ps(b);
1827 const __m512 vbGHIJKLMNOPQRSTUV = _mm512_loadu_ps(b + 16);
1828 b += 32;
1829
1830 __m512 vy0123456789ABCDEF = _mm512_min_ps(va0123456789ABCDEF, vb0123456789ABCDEF);
1831 __m512 vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vaGHIJKLMNOPQRSTUV, vbGHIJKLMNOPQRSTUV);
1832
1833
1834
1835 _mm512_storeu_ps(y, vy0123456789ABCDEF);
1836 _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
1837 y += 32;
1838 }
1839 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
1840 const __m512 va = _mm512_loadu_ps(a);
1841 a += 16;
1842
1843 const __m512 vb = _mm512_loadu_ps(b);
1844 b += 16;
1845
1846 __m512 vy = _mm512_min_ps(va, vb);
1847 _mm512_storeu_ps(y, vy);
1848 y += 16;
1849 }
1850 if XNN_UNLIKELY(n != 0) {
1851 assert(n >= 1 * sizeof(float));
1852 assert(n <= 15 * sizeof(float));
1853 // Prepare mask for valid 32-bit elements (depends on n).
1854 n >>= 2 /* log2(sizeof(float)) */;
1855 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
1856
1857 const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
1858 const __m512 vb = _mm512_maskz_loadu_ps(vmask, b);
1859
1860 __m512 vy = _mm512_min_ps(va, vb);
1861 _mm512_mask_storeu_ps(y, vmask, vy);
1862 }
1863 }
1864
1865 void xnn_f32_vminc_ukernel__avx512f_x32(
1866 size_t n,
1867 const float* a,
1868 const float* b,
1869 float* y,
1870 const union xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
1871 {
1872 assert(n != 0);
1873 assert(n % sizeof(float) == 0);
1874 assert(a != NULL);
1875 assert(b != NULL);
1876 assert(y != NULL);
1877
1878
1879 const __m512 vb = _mm512_set1_ps(*b);
1880 for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
1881 const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
1882 const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
1883 a += 32;
1884
1885 __m512 vy0123456789ABCDEF = _mm512_min_ps(va0123456789ABCDEF, vb);
1886 __m512 vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vaGHIJKLMNOPQRSTUV, vb);
1887
1888
1889
1890 _mm512_storeu_ps(y, vy0123456789ABCDEF);
1891 _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
1892 y += 32;
1893 }
1894 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
1895 const __m512 va = _mm512_loadu_ps(a);
1896 a += 16;
1897
1898 __m512 vy = _mm512_min_ps(va, vb);
1899 _mm512_storeu_ps(y, vy);
1900 y += 16;
1901 }
1902 if XNN_UNLIKELY(n != 0) {
1903 assert(n >= 1 * sizeof(float));
1904 assert(n <= 15 * sizeof(float));
1905 // Prepare mask for valid 32-bit elements (depends on n).
1906 n >>= 2 /* log2(sizeof(float)) */;
1907 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
1908
1909 const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
1910
1911 __m512 vy = _mm512_min_ps(va, vb);
1912 _mm512_mask_storeu_ps(y, vmask, vy);
1913 }
1914 }
1915
1916 void xnn_f32_vmul_minmax_ukernel__avx512f_x32(
1917 size_t n,
1918 const float* a,
1919 const float* b,
1920 float* y,
1921 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
1922 {
1923 assert(n != 0);
1924 assert(n % sizeof(float) == 0);
1925 assert(a != NULL);
1926 assert(b != NULL);
1927 assert(y != NULL);
1928
1929 const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
1930 const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
1931
1932 for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
1933 const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
1934 const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
1935 a += 32;
1936
1937 const __m512 vb0123456789ABCDEF = _mm512_loadu_ps(b);
1938 const __m512 vbGHIJKLMNOPQRSTUV = _mm512_loadu_ps(b + 16);
1939 b += 32;
1940
1941 __m512 vy0123456789ABCDEF = _mm512_mul_ps(va0123456789ABCDEF, vb0123456789ABCDEF);
1942 __m512 vyGHIJKLMNOPQRSTUV = _mm512_mul_ps(vaGHIJKLMNOPQRSTUV, vbGHIJKLMNOPQRSTUV);
1943
1944
1945 vy0123456789ABCDEF = _mm512_max_ps(vy0123456789ABCDEF, vy_min);
1946 vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vyGHIJKLMNOPQRSTUV, vy_min);
1947
1948 vy0123456789ABCDEF = _mm512_min_ps(vy0123456789ABCDEF, vy_max);
1949 vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vyGHIJKLMNOPQRSTUV, vy_max);
1950
1951 _mm512_storeu_ps(y, vy0123456789ABCDEF);
1952 _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
1953 y += 32;
1954 }
1955 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
1956 const __m512 va = _mm512_loadu_ps(a);
1957 a += 16;
1958
1959 const __m512 vb = _mm512_loadu_ps(b);
1960 b += 16;
1961
1962 __m512 vy = _mm512_mul_ps(va, vb);
1963 vy = _mm512_max_ps(vy, vy_min);
1964 vy = _mm512_min_ps(vy, vy_max);
1965 _mm512_storeu_ps(y, vy);
1966 y += 16;
1967 }
1968 if XNN_UNLIKELY(n != 0) {
1969 assert(n >= 1 * sizeof(float));
1970 assert(n <= 15 * sizeof(float));
1971 // Prepare mask for valid 32-bit elements (depends on n).
1972 n >>= 2 /* log2(sizeof(float)) */;
1973 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
1974
1975 const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
1976 const __m512 vb = _mm512_maskz_loadu_ps(vmask, b);
1977
1978 __m512 vy = _mm512_mul_ps(va, vb);
1979 vy = _mm512_max_ps(vy, vy_min);
1980 vy = _mm512_min_ps(vy, vy_max);
1981 _mm512_mask_storeu_ps(y, vmask, vy);
1982 }
1983 }
1984
1985 void xnn_f32_vmulc_minmax_ukernel__avx512f_x32(
1986 size_t n,
1987 const float* a,
1988 const float* b,
1989 float* y,
1990 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
1991 {
1992 assert(n != 0);
1993 assert(n % sizeof(float) == 0);
1994 assert(a != NULL);
1995 assert(b != NULL);
1996 assert(y != NULL);
1997
1998 const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
1999 const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
2000
2001 const __m512 vb = _mm512_set1_ps(*b);
2002 for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
2003 const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
2004 const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
2005 a += 32;
2006
2007 __m512 vy0123456789ABCDEF = _mm512_mul_ps(va0123456789ABCDEF, vb);
2008 __m512 vyGHIJKLMNOPQRSTUV = _mm512_mul_ps(vaGHIJKLMNOPQRSTUV, vb);
2009
2010
2011 vy0123456789ABCDEF = _mm512_max_ps(vy0123456789ABCDEF, vy_min);
2012 vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vyGHIJKLMNOPQRSTUV, vy_min);
2013
2014 vy0123456789ABCDEF = _mm512_min_ps(vy0123456789ABCDEF, vy_max);
2015 vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vyGHIJKLMNOPQRSTUV, vy_max);
2016
2017 _mm512_storeu_ps(y, vy0123456789ABCDEF);
2018 _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
2019 y += 32;
2020 }
2021 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2022 const __m512 va = _mm512_loadu_ps(a);
2023 a += 16;
2024
2025 __m512 vy = _mm512_mul_ps(va, vb);
2026 vy = _mm512_max_ps(vy, vy_min);
2027 vy = _mm512_min_ps(vy, vy_max);
2028 _mm512_storeu_ps(y, vy);
2029 y += 16;
2030 }
2031 if XNN_UNLIKELY(n != 0) {
2032 assert(n >= 1 * sizeof(float));
2033 assert(n <= 15 * sizeof(float));
2034 // Prepare mask for valid 32-bit elements (depends on n).
2035 n >>= 2 /* log2(sizeof(float)) */;
2036 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2037
2038 const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
2039
2040 __m512 vy = _mm512_mul_ps(va, vb);
2041 vy = _mm512_max_ps(vy, vy_min);
2042 vy = _mm512_min_ps(vy, vy_max);
2043 _mm512_mask_storeu_ps(y, vmask, vy);
2044 }
2045 }
2046
2047 void xnn_f32_vrdivc_minmax_ukernel__avx512f_x32(
2048 size_t n,
2049 const float* a,
2050 const float* b,
2051 float* y,
2052 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
2053 {
2054 assert(n != 0);
2055 assert(n % sizeof(float) == 0);
2056 assert(a != NULL);
2057 assert(b != NULL);
2058 assert(y != NULL);
2059
2060 const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
2061 const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
2062
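// RDIVC variant: the broadcast scalar *b is the dividend, i.e. y[i] = b / a[i].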
2063 const __m512 vb = _mm512_set1_ps(*b);
2064 for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
2065 const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
2066 const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
2067 a += 32;
2068
2069 __m512 vy0123456789ABCDEF = _mm512_div_ps(vb, va0123456789ABCDEF);
2070 __m512 vyGHIJKLMNOPQRSTUV = _mm512_div_ps(vb, vaGHIJKLMNOPQRSTUV);
2071
2072
2073 vy0123456789ABCDEF = _mm512_max_ps(vy0123456789ABCDEF, vy_min);
2074 vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vyGHIJKLMNOPQRSTUV, vy_min);
2075
2076 vy0123456789ABCDEF = _mm512_min_ps(vy0123456789ABCDEF, vy_max);
2077 vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vyGHIJKLMNOPQRSTUV, vy_max);
2078
2079 _mm512_storeu_ps(y, vy0123456789ABCDEF);
2080 _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
2081 y += 32;
2082 }
2083 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2084 const __m512 va = _mm512_loadu_ps(a);
2085 a += 16;
2086
2087 __m512 vy = _mm512_div_ps(vb, va);
2088 vy = _mm512_max_ps(vy, vy_min);
2089 vy = _mm512_min_ps(vy, vy_max);
2090 _mm512_storeu_ps(y, vy);
2091 y += 16;
2092 }
2093 if XNN_UNLIKELY(n != 0) {
2094 assert(n >= 1 * sizeof(float));
2095 assert(n <= 15 * sizeof(float));
2096 // Prepare mask for valid 32-bit elements (depends on n).
2097 n >>= 2 /* log2(sizeof(float)) */;
2098 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2099
2100 const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
2101
2102 __m512 vy = _mm512_div_ps(vb, va);
2103 vy = _mm512_max_ps(vy, vy_min);
2104 vy = _mm512_min_ps(vy, vy_max);
2105 _mm512_mask_storeu_ps(y, vmask, vy);
2106 }
2107 }
2108
2109 void xnn_f32_vrsubc_minmax_ukernel__avx512f_x32(
2110 size_t n,
2111 const float* a,
2112 const float* b,
2113 float* y,
2114 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
2115 {
2116 assert(n != 0);
2117 assert(n % sizeof(float) == 0);
2118 assert(a != NULL);
2119 assert(b != NULL);
2120 assert(y != NULL);
2121
2122 const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
2123 const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
2124
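// RSUBC variant: the broadcast scalar *b is the minuend, i.e. y[i] = b - a[i].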
2125 const __m512 vb = _mm512_set1_ps(*b);
2126 for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
2127 const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
2128 const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
2129 a += 32;
2130
2131 __m512 vy0123456789ABCDEF = _mm512_sub_ps(vb, va0123456789ABCDEF);
2132 __m512 vyGHIJKLMNOPQRSTUV = _mm512_sub_ps(vb, vaGHIJKLMNOPQRSTUV);
2133
2134
2135 vy0123456789ABCDEF = _mm512_max_ps(vy0123456789ABCDEF, vy_min);
2136 vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vyGHIJKLMNOPQRSTUV, vy_min);
2137
2138 vy0123456789ABCDEF = _mm512_min_ps(vy0123456789ABCDEF, vy_max);
2139 vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vyGHIJKLMNOPQRSTUV, vy_max);
2140
2141 _mm512_storeu_ps(y, vy0123456789ABCDEF);
2142 _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
2143 y += 32;
2144 }
2145 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2146 const __m512 va = _mm512_loadu_ps(a);
2147 a += 16;
2148
2149 __m512 vy = _mm512_sub_ps(vb, va);
2150 vy = _mm512_max_ps(vy, vy_min);
2151 vy = _mm512_min_ps(vy, vy_max);
2152 _mm512_storeu_ps(y, vy);
2153 y += 16;
2154 }
2155 if XNN_UNLIKELY(n != 0) {
2156 assert(n >= 1 * sizeof(float));
2157 assert(n <= 15 * sizeof(float));
2158 // Prepare mask for valid 32-bit elements (depends on n).
2159 n >>= 2 /* log2(sizeof(float)) */;
2160 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2161
2162 const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
2163
2164 __m512 vy = _mm512_sub_ps(vb, va);
2165 vy = _mm512_max_ps(vy, vy_min);
2166 vy = _mm512_min_ps(vy, vy_max);
2167 _mm512_mask_storeu_ps(y, vmask, vy);
2168 }
2169 }
2170
2171 void xnn_f32_vsqrdiff_ukernel__avx512f_x32(
2172 size_t n,
2173 const float* a,
2174 const float* b,
2175 float* y,
2176 const union xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
2177 {
2178 assert(n != 0);
2179 assert(n % sizeof(float) == 0);
2180 assert(a != NULL);
2181 assert(b != NULL);
2182 assert(y != NULL);
2183
2184
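// Squared difference: y[i] = (a[i] - b[i])^2, computed as a subtract followed by a self-multiply.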
2185 for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
2186 const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
2187 const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
2188 a += 32;
2189
2190 const __m512 vb0123456789ABCDEF = _mm512_loadu_ps(b);
2191 const __m512 vbGHIJKLMNOPQRSTUV = _mm512_loadu_ps(b + 16);
2192 b += 32;
2193
2194 __m512 vy0123456789ABCDEF = _mm512_sub_ps(va0123456789ABCDEF, vb0123456789ABCDEF);
2195 __m512 vyGHIJKLMNOPQRSTUV = _mm512_sub_ps(vaGHIJKLMNOPQRSTUV, vbGHIJKLMNOPQRSTUV);
2196
2197 vy0123456789ABCDEF = _mm512_mul_ps(vy0123456789ABCDEF, vy0123456789ABCDEF);
2198 vyGHIJKLMNOPQRSTUV = _mm512_mul_ps(vyGHIJKLMNOPQRSTUV, vyGHIJKLMNOPQRSTUV);
2199
2200
2201 _mm512_storeu_ps(y, vy0123456789ABCDEF);
2202 _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
2203 y += 32;
2204 }
2205 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2206 const __m512 va = _mm512_loadu_ps(a);
2207 a += 16;
2208
2209 const __m512 vb = _mm512_loadu_ps(b);
2210 b += 16;
2211
2212 __m512 vy = _mm512_sub_ps(va, vb);
2213 vy = _mm512_mul_ps(vy, vy);
2214 _mm512_storeu_ps(y, vy);
2215 y += 16;
2216 }
2217 if XNN_UNLIKELY(n != 0) {
2218 assert(n >= 1 * sizeof(float));
2219 assert(n <= 15 * sizeof(float));
2220 // Prepare mask for valid 32-bit elements (depends on n).
2221 n >>= 2 /* log2(sizeof(float)) */;
2222 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2223
2224 const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
2225 const __m512 vb = _mm512_maskz_loadu_ps(vmask, b);
2226
2227 __m512 vy = _mm512_sub_ps(va, vb);
2228 vy = _mm512_mul_ps(vy, vy);
2229 _mm512_mask_storeu_ps(y, vmask, vy);
2230 }
2231 }
2232
2233 void xnn_f32_vsqrdiffc_ukernel__avx512f_x32(
2234 size_t n,
2235 const float* a,
2236 const float* b,
2237 float* y,
2238 const union xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
2239 {
2240 assert(n != 0);
2241 assert(n % sizeof(float) == 0);
2242 assert(a != NULL);
2243 assert(b != NULL);
2244 assert(y != NULL);
2245
2246
2247 const __m512 vb = _mm512_set1_ps(*b);
2248 for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
2249 const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
2250 const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
2251 a += 32;
2252
2253 __m512 vy0123456789ABCDEF = _mm512_sub_ps(va0123456789ABCDEF, vb);
2254 __m512 vyGHIJKLMNOPQRSTUV = _mm512_sub_ps(vaGHIJKLMNOPQRSTUV, vb);
2255
2256 vy0123456789ABCDEF = _mm512_mul_ps(vy0123456789ABCDEF, vy0123456789ABCDEF);
2257 vyGHIJKLMNOPQRSTUV = _mm512_mul_ps(vyGHIJKLMNOPQRSTUV, vyGHIJKLMNOPQRSTUV);
2258
2259
2260 _mm512_storeu_ps(y, vy0123456789ABCDEF);
2261 _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
2262 y += 32;
2263 }
2264 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2265 const __m512 va = _mm512_loadu_ps(a);
2266 a += 16;
2267
2268 __m512 vy = _mm512_sub_ps(va, vb);
2269 vy = _mm512_mul_ps(vy, vy);
2270 _mm512_storeu_ps(y, vy);
2271 y += 16;
2272 }
2273 if XNN_UNLIKELY(n != 0) {
2274 assert(n >= 1 * sizeof(float));
2275 assert(n <= 15 * sizeof(float));
2276 // Prepare mask for valid 32-bit elements (depends on n).
2277 n >>= 2 /* log2(sizeof(float)) */;
2278 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2279
2280 const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
2281
2282 __m512 vy = _mm512_sub_ps(va, vb);
2283 vy = _mm512_mul_ps(vy, vy);
2284 _mm512_mask_storeu_ps(y, vmask, vy);
2285 }
2286 }
2287
2288 void xnn_f32_vsub_minmax_ukernel__avx512f_x32(
2289 size_t n,
2290 const float* a,
2291 const float* b,
2292 float* y,
2293 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
2294 {
2295 assert(n != 0);
2296 assert(n % sizeof(float) == 0);
2297 assert(a != NULL);
2298 assert(b != NULL);
2299 assert(y != NULL);
2300
2301 const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
2302 const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
2303
2304 for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
2305 const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
2306 const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
2307 a += 32;
2308
2309 const __m512 vb0123456789ABCDEF = _mm512_loadu_ps(b);
2310 const __m512 vbGHIJKLMNOPQRSTUV = _mm512_loadu_ps(b + 16);
2311 b += 32;
2312
2313 __m512 vy0123456789ABCDEF = _mm512_sub_ps(va0123456789ABCDEF, vb0123456789ABCDEF);
2314 __m512 vyGHIJKLMNOPQRSTUV = _mm512_sub_ps(vaGHIJKLMNOPQRSTUV, vbGHIJKLMNOPQRSTUV);
2315
2316
2317 vy0123456789ABCDEF = _mm512_max_ps(vy0123456789ABCDEF, vy_min);
2318 vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vyGHIJKLMNOPQRSTUV, vy_min);
2319
2320 vy0123456789ABCDEF = _mm512_min_ps(vy0123456789ABCDEF, vy_max);
2321 vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vyGHIJKLMNOPQRSTUV, vy_max);
2322
2323 _mm512_storeu_ps(y, vy0123456789ABCDEF);
2324 _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
2325 y += 32;
2326 }
2327 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2328 const __m512 va = _mm512_loadu_ps(a);
2329 a += 16;
2330
2331 const __m512 vb = _mm512_loadu_ps(b);
2332 b += 16;
2333
2334 __m512 vy = _mm512_sub_ps(va, vb);
2335 vy = _mm512_max_ps(vy, vy_min);
2336 vy = _mm512_min_ps(vy, vy_max);
2337 _mm512_storeu_ps(y, vy);
2338 y += 16;
2339 }
2340 if XNN_UNLIKELY(n != 0) {
2341 assert(n >= 1 * sizeof(float));
2342 assert(n <= 15 * sizeof(float));
2343 // Prepare mask for valid 32-bit elements (depends on n).
2344 n >>= 2 /* log2(sizeof(float)) */;
2345 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2346
2347 const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
2348 const __m512 vb = _mm512_maskz_loadu_ps(vmask, b);
2349
2350 __m512 vy = _mm512_sub_ps(va, vb);
2351 vy = _mm512_max_ps(vy, vy_min);
2352 vy = _mm512_min_ps(vy, vy_max);
2353 _mm512_mask_storeu_ps(y, vmask, vy);
2354 }
2355 }
2356
2357 void xnn_f32_vsubc_minmax_ukernel__avx512f_x32(
2358 size_t n,
2359 const float* a,
2360 const float* b,
2361 float* y,
2362 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
2363 {
2364 assert(n != 0);
2365 assert(n % sizeof(float) == 0);
2366 assert(a != NULL);
2367 assert(b != NULL);
2368 assert(y != NULL);
2369
2370 const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
2371 const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
2372
2373 const __m512 vb = _mm512_set1_ps(*b);
2374 for (; n >= 32 * sizeof(float); n -= 32 * sizeof(float)) {
2375 const __m512 va0123456789ABCDEF = _mm512_loadu_ps(a);
2376 const __m512 vaGHIJKLMNOPQRSTUV = _mm512_loadu_ps(a + 16);
2377 a += 32;
2378
2379 __m512 vy0123456789ABCDEF = _mm512_sub_ps(va0123456789ABCDEF, vb);
2380 __m512 vyGHIJKLMNOPQRSTUV = _mm512_sub_ps(vaGHIJKLMNOPQRSTUV, vb);
2381
2382
2383 vy0123456789ABCDEF = _mm512_max_ps(vy0123456789ABCDEF, vy_min);
2384 vyGHIJKLMNOPQRSTUV = _mm512_max_ps(vyGHIJKLMNOPQRSTUV, vy_min);
2385
2386 vy0123456789ABCDEF = _mm512_min_ps(vy0123456789ABCDEF, vy_max);
2387 vyGHIJKLMNOPQRSTUV = _mm512_min_ps(vyGHIJKLMNOPQRSTUV, vy_max);
2388
2389 _mm512_storeu_ps(y, vy0123456789ABCDEF);
2390 _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
2391 y += 32;
2392 }
2393 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2394 const __m512 va = _mm512_loadu_ps(a);
2395 a += 16;
2396
2397 __m512 vy = _mm512_sub_ps(va, vb);
2398 vy = _mm512_max_ps(vy, vy_min);
2399 vy = _mm512_min_ps(vy, vy_max);
2400 _mm512_storeu_ps(y, vy);
2401 y += 16;
2402 }
2403 if XNN_UNLIKELY(n != 0) {
2404 assert(n >= 1 * sizeof(float));
2405 assert(n <= 15 * sizeof(float));
2406 // Prepare mask for valid 32-bit elements (depends on n).
2407 n >>= 2 /* log2(sizeof(float)) */;
2408 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2409
2410 const __m512 va = _mm512_maskz_loadu_ps(vmask, a);
2411
2412 __m512 vy = _mm512_sub_ps(va, vb);
2413 vy = _mm512_max_ps(vy, vy_min);
2414 vy = _mm512_min_ps(vy, vy_max);
2415 _mm512_mask_storeu_ps(y, vmask, vy);
2416 }
2417 }
2418
2419 void xnn_f32_vclamp_ukernel__avx512f_x16(
2420 size_t n,
2421 const float* x,
2422 float* y,
2423 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
2424 {
2425 assert(n != 0);
2426 assert(n % sizeof(float) == 0);
2427 assert(x != NULL);
2428 assert(y != NULL);
2429
2430 const __m512 vy_min = _mm512_set1_ps(params->scalar.min);
2431 const __m512 vy_max = _mm512_set1_ps(params->scalar.max);
2432
2433 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2434 __m512 vacc0123456789ABCDEF = _mm512_loadu_ps(x);
2435 x += 16;
2436
2437 vacc0123456789ABCDEF = _mm512_max_ps(vacc0123456789ABCDEF, vy_min);
2438
2439 vacc0123456789ABCDEF = _mm512_min_ps(vacc0123456789ABCDEF, vy_max);
2440
2441 _mm512_storeu_ps(y, vacc0123456789ABCDEF);
2442 y += 16;
2443 }
2444 if XNN_UNLIKELY(n != 0) {
2445 assert(n >= 1 * sizeof(float));
2446 assert(n <= 15 * sizeof(float));
2447 // Prepare mask for valid 32-bit elements (depends on n).
2448 n >>= 2 /* log2(sizeof(float)) */;
2449 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2450
2451 __m512 vacc = _mm512_maskz_loadu_ps(vmask, x);
2452 vacc = _mm512_max_ps(vacc, vy_min);
2453 vacc = _mm512_min_ps(vacc, vy_max);
2454 _mm512_mask_storeu_ps(y, vmask, vacc);
2455 }
2456 }
2457
2458 void xnn_f32_velu_ukernel__avx512f_rr1_lut16_p3_perm_x64(
2459 size_t n,
2460 const float* x,
2461 float* y,
2462 const union xnn_f32_elu_params params[restrict XNN_MIN_ELEMENTS(1)])
2463 {
2464 assert(n != 0);
2465 assert(n % sizeof(float) == 0);
2466
2467 const __m512 vprescale = _mm512_set1_ps(params->avx512_rr1_lut16_p3.prescale);
2468 const __m512 valpha = _mm512_set1_ps(params->avx512_rr1_lut16_p3.alpha);
2469 const __m512 vbeta = _mm512_set1_ps(params->avx512_rr1_lut16_p3.beta);
2470 const __m512 vsat_cutoff = _mm512_set1_ps(params->avx512_rr1_lut16_p3.sat_cutoff);
2471 const __m512 vmagic_bias = _mm512_set1_ps(params->avx512_rr1_lut16_p3.magic_bias);
2472 const __m512 vlog2e = _mm512_set1_ps(params->avx512_rr1_lut16_p3.log2e);
2473 const __m512 vminus_ln2 = _mm512_set1_ps(params->avx512_rr1_lut16_p3.minus_ln2);
2474 const __m512 vc3 = _mm512_set1_ps(params->avx512_rr1_lut16_p3.c3);
2475 const __m512 vc2 = _mm512_set1_ps(params->avx512_rr1_lut16_p3.c2);
2476 const __m512i vtable = _mm512_load_si512(params->avx512_rr1_lut16_p3.table);
2477
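// ELU: y = beta * x for x >= 0 and y = alpha * (exp(prescale * x) - 1) otherwise.
// exp() follows the rr1_lut16_p3 scheme: the scaled input is saturated at sat_cutoff,
// the magic-bias trick produces an integer n whose low 4 bits index a 16-entry table
// of 2^(i/16) while the remaining bits, shifted into the exponent field (<< 19),
// supply the power of two; a degree-3 polynomial in the residual t reconstructs the
// fractional part, and a final masked multiply selects beta * x for non-negative lanes.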
2478 for (; n >= 64 * sizeof(float); n -= 64 * sizeof(float)) {
2479 __m512 vx0 = _mm512_loadu_ps(x);
2480 __m512 vx1 = _mm512_loadu_ps(x + 16);
2481 __m512 vx2 = _mm512_loadu_ps(x + 32);
2482 __m512 vx3 = _mm512_loadu_ps(x + 48);
2483 x += 64;
2484
2485 const __m512 vz0 = _mm512_max_ps(vsat_cutoff, _mm512_mul_ps(vx0, vprescale));
2486 const __m512 vz1 = _mm512_max_ps(vsat_cutoff, _mm512_mul_ps(vx1, vprescale));
2487 const __m512 vz2 = _mm512_max_ps(vsat_cutoff, _mm512_mul_ps(vx2, vprescale));
2488 const __m512 vz3 = _mm512_max_ps(vsat_cutoff, _mm512_mul_ps(vx3, vprescale));
2489
2490 __m512 vn0 = _mm512_fmadd_ps(vz0, vlog2e, vmagic_bias);
2491 __m512 vn1 = _mm512_fmadd_ps(vz1, vlog2e, vmagic_bias);
2492 __m512 vn2 = _mm512_fmadd_ps(vz2, vlog2e, vmagic_bias);
2493 __m512 vn3 = _mm512_fmadd_ps(vz3, vlog2e, vmagic_bias);
2494
2495 const __m512i ven0 = _mm512_slli_epi32(_mm512_castps_si512(vn0), 19);
2496 const __m512i vl0 = _mm512_permutexvar_epi32(_mm512_castps_si512(vn0), vtable);
2497 const __m512i ven1 = _mm512_slli_epi32(_mm512_castps_si512(vn1), 19);
2498 const __m512i vl1 = _mm512_permutexvar_epi32(_mm512_castps_si512(vn1), vtable);
2499 const __m512i ven2 = _mm512_slli_epi32(_mm512_castps_si512(vn2), 19);
2500 const __m512i vl2 = _mm512_permutexvar_epi32(_mm512_castps_si512(vn2), vtable);
2501 const __m512i ven3 = _mm512_slli_epi32(_mm512_castps_si512(vn3), 19);
2502 const __m512i vl3 = _mm512_permutexvar_epi32(_mm512_castps_si512(vn3), vtable);
2503
2504 __m512 vs0 = _mm512_castsi512_ps(_mm512_add_epi32(vl0, ven0));
2505 vn0 = _mm512_sub_ps(vn0, vmagic_bias);
2506 __m512 vs1 = _mm512_castsi512_ps(_mm512_add_epi32(vl1, ven1));
2507 vn1 = _mm512_sub_ps(vn1, vmagic_bias);
2508 __m512 vs2 = _mm512_castsi512_ps(_mm512_add_epi32(vl2, ven2));
2509 vn2 = _mm512_sub_ps(vn2, vmagic_bias);
2510 __m512 vs3 = _mm512_castsi512_ps(_mm512_add_epi32(vl3, ven3));
2511 vn3 = _mm512_sub_ps(vn3, vmagic_bias);
2512
2513 __m512 vt0 = _mm512_fmadd_ps(vn0, vminus_ln2, vz0);
2514 __m512 vt1 = _mm512_fmadd_ps(vn1, vminus_ln2, vz1);
2515 __m512 vt2 = _mm512_fmadd_ps(vn2, vminus_ln2, vz2);
2516 __m512 vt3 = _mm512_fmadd_ps(vn3, vminus_ln2, vz3);
2517
2518 __m512 vp0 = _mm512_fmadd_ps(vc3, vt0, vc2);
2519 __m512 vp1 = _mm512_fmadd_ps(vc3, vt1, vc2);
2520 __m512 vp2 = _mm512_fmadd_ps(vc3, vt2, vc2);
2521 __m512 vp3 = _mm512_fmadd_ps(vc3, vt3, vc2);
2522
2523 vp0 = _mm512_mul_ps(vp0, vt0);
2524 vt0 = _mm512_mul_ps(vt0, vs0);
2525 vp1 = _mm512_mul_ps(vp1, vt1);
2526 vt1 = _mm512_mul_ps(vt1, vs1);
2527 vp2 = _mm512_mul_ps(vp2, vt2);
2528 vt2 = _mm512_mul_ps(vt2, vs2);
2529 vp3 = _mm512_mul_ps(vp3, vt3);
2530 vt3 = _mm512_mul_ps(vt3, vs3);
2531
2532 vs0 = _mm512_fmsub_ps(vs0, valpha, valpha);
2533 vs1 = _mm512_fmsub_ps(vs1, valpha, valpha);
2534 vs2 = _mm512_fmsub_ps(vs2, valpha, valpha);
2535 vs3 = _mm512_fmsub_ps(vs3, valpha, valpha);
2536
2537 vp0 = _mm512_fmadd_ps(vp0, vt0, vt0);
2538 vp1 = _mm512_fmadd_ps(vp1, vt1, vt1);
2539 vp2 = _mm512_fmadd_ps(vp2, vt2, vt2);
2540 vp3 = _mm512_fmadd_ps(vp3, vt3, vt3);
2541
2542 const __m512 vzero = _mm512_setzero_ps();
2543 __m512 vy0 = _mm512_fmadd_ps(vp0, valpha, vs0);
2544 const __mmask16 vsign0 = _mm512_cmp_ps_mask(vx0, vzero, _CMP_NLT_US);
2545 __m512 vy1 = _mm512_fmadd_ps(vp1, valpha, vs1);
2546 const __mmask16 vsign1 = _mm512_cmp_ps_mask(vx1, vzero, _CMP_NLT_US);
2547 __m512 vy2 = _mm512_fmadd_ps(vp2, valpha, vs2);
2548 const __mmask16 vsign2 = _mm512_cmp_ps_mask(vx2, vzero, _CMP_NLT_US);
2549 __m512 vy3 = _mm512_fmadd_ps(vp3, valpha, vs3);
2550 const __mmask16 vsign3 = _mm512_cmp_ps_mask(vx3, vzero, _CMP_NLT_US);
2551
2552 vy0 = _mm512_mask_mul_ps(vy0, vsign0, vx0, vbeta);
2553 vy1 = _mm512_mask_mul_ps(vy1, vsign1, vx1, vbeta);
2554 vy2 = _mm512_mask_mul_ps(vy2, vsign2, vx2, vbeta);
2555 vy3 = _mm512_mask_mul_ps(vy3, vsign3, vx3, vbeta);
2556
2557 _mm512_storeu_ps(y, vy0);
2558 _mm512_storeu_ps(y + 16, vy1);
2559 _mm512_storeu_ps(y + 32, vy2);
2560 _mm512_storeu_ps(y + 48, vy3);
2561 y += 64;
2562 }
2563 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2564 __m512 vx = _mm512_loadu_ps(x);
2565 x += 16;
2566
2567 const __m512 vz = _mm512_max_ps(vsat_cutoff, _mm512_mul_ps(vx, vprescale));
2568 const __mmask16 vsign = _mm512_cmp_ps_mask(vx, _mm512_setzero_ps(), _CMP_NLT_US);
2569
2570 __m512 vn = _mm512_fmadd_ps(vz, vlog2e, vmagic_bias);
2571 const __m512i ven = _mm512_slli_epi32(_mm512_castps_si512(vn), 19);
2572 const __m512i vl = _mm512_permutexvar_epi32(_mm512_castps_si512(vn), vtable);
2573 __m512 vs = _mm512_castsi512_ps(_mm512_add_epi32(vl, ven));
2574 vn = _mm512_sub_ps(vn, vmagic_bias);
2575
2576 __m512 vt = _mm512_fmadd_ps(vn, vminus_ln2, vz);
2577
2578 __m512 vp = _mm512_fmadd_ps(vc3, vt, vc2);
2579 vp = _mm512_mul_ps(vp, vt);
2580
2581 vt = _mm512_mul_ps(vt, vs);
2582 vs = _mm512_fmsub_ps(vs, valpha, valpha);
2583 vp = _mm512_fmadd_ps(vp, vt, vt);
2584 __m512 vy = _mm512_fmadd_ps(vp, valpha, vs);
2585
2586 vy = _mm512_mask_mul_ps(vy, vsign, vx, vbeta);
2587
2588 _mm512_storeu_ps(y, vy);
2589 y += 16;
2590 }
2591 if XNN_UNLIKELY(n != 0) {
2592 assert(n >= 1 * sizeof(float));
2593 assert(n <= 15 * sizeof(float));
2594 // Prepare mask for valid 32-bit elements (depends on n).
2595 n >>= 2 /* log2(sizeof(float)) */;
2596 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2597
2598 __m512 vx = _mm512_maskz_loadu_ps(vmask, x);
2599
2600 const __m512 vz = _mm512_max_ps(vsat_cutoff, _mm512_mul_ps(vx, vprescale));
2601 const __mmask16 vsign = _mm512_cmp_ps_mask(vx, _mm512_setzero_ps(), _CMP_NLT_US);
2602
2603 __m512 vn = _mm512_fmadd_ps(vz, vlog2e, vmagic_bias);
2604 const __m512i ven = _mm512_slli_epi32(_mm512_castps_si512(vn), 19);
2605 const __m512i vl = _mm512_permutexvar_epi32(_mm512_castps_si512(vn), vtable);
2606 __m512 vs = _mm512_castsi512_ps(_mm512_add_epi32(vl, ven));
2607 vn = _mm512_sub_ps(vn, vmagic_bias);
2608
2609 __m512 vt = _mm512_fmadd_ps(vn, vminus_ln2, vz);
2610
2611 __m512 vp = _mm512_fmadd_ps(vc3, vt, vc2);
2612 vp = _mm512_mul_ps(vp, vt);
2613
2614 vt = _mm512_mul_ps(vt, vs);
2615 vs = _mm512_fmsub_ps(vs, valpha, valpha);
2616 vp = _mm512_fmadd_ps(vp, vt, vt);
2617 __m512 vy = _mm512_fmadd_ps(vp, valpha, vs);
2618
2619 vy = _mm512_mask_mul_ps(vy, vsign, vx, vbeta);
2620
2621 _mm512_mask_storeu_ps(y, vmask, vy);
2622 }
2623 }
2624
2625 void xnn_f32_vhswish_ukernel__avx512f_x16(
2626 size_t n,
2627 const float* x,
2628 float* y,
2629 const union xnn_f32_hswish_params params[restrict XNN_MIN_ELEMENTS(1)])
2630 {
2631 assert(n != 0);
2632 assert(n % sizeof(float) == 0);
2633
2634 const __m512 vsixth = _mm512_set1_ps(params->avx512.sixth);
2635 const __m512 vhalf = _mm512_set1_ps(params->avx512.half);
2636 const __m512 vone = _mm512_set1_ps(params->avx512.one);
2637 const __m512 vzero = _mm512_setzero_ps();
2638
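// hardswish(x) = x * min(max(x / 6 + 1/2, 0), 1), with 1/6, 1/2, and 1 preloaded from params.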
2639 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2640 const __m512 vx = _mm512_loadu_ps(x);
2641 x += 16;
2642 __m512 vacc = _mm512_fmadd_ps(vx, vsixth, vhalf);
2643 vacc = _mm512_max_ps(vacc, vzero);
2644 vacc = _mm512_min_ps(vacc, vone);
2645 vacc = _mm512_mul_ps(vacc, vx);
2646 _mm512_storeu_ps(y, vacc);
2647 y += 16;
2648 }
2649 if XNN_UNLIKELY(n != 0) {
2650 assert(n >= 1 * sizeof(float));
2651 assert(n <= 15 * sizeof(float));
2652 // Prepare mask for valid 32-bit elements (depends on n).
2653 n >>= 2 /* log2(sizeof(float)) */;
2654 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2655
2656 const __m512 vx = _mm512_maskz_loadu_ps(vmask, x);
2657 __m512 vacc = _mm512_fmadd_ps(vx, vsixth, vhalf);
2658 vacc = _mm512_max_ps(vacc, vzero);
2659 vacc = _mm512_min_ps(vacc, vone);
2660 vacc = _mm512_mul_ps(vacc, vx);
2661 _mm512_mask_storeu_ps(y, vmask, vacc);
2662 }
2663 }
2664
2665 void xnn_f32_vlrelu_ukernel__avx512f_x16(
2666 size_t n,
2667 const float* x,
2668 float* y,
2669 const union xnn_f32_lrelu_params params[restrict XNN_MIN_ELEMENTS(1)])
2670 {
2671 assert(n != 0);
2672 assert(n % sizeof(float) == 0);
2673
2674 const __m512 vslope = _mm512_set1_ps(params->scalar.slope);
2675 const __m512 vzero = _mm512_setzero_ps();
2676
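// Leaky ReLU: lanes with a negative input are scaled by slope via a masked multiply;
// the rest pass through unchanged.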
2677 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2678 __m512 vacc0123456789ABCDEF = _mm512_loadu_ps(x);
2679 x += 16;
2680
2681 const __mmask16 vsign0123456789ABCDEF = _mm512_cmp_ps_mask(vacc0123456789ABCDEF, vzero, _CMP_LT_OQ);
2682
2683 vacc0123456789ABCDEF = _mm512_mask_mul_ps(vacc0123456789ABCDEF, vsign0123456789ABCDEF, vacc0123456789ABCDEF, vslope);
2684
2685 _mm512_storeu_ps(y, vacc0123456789ABCDEF);
2686 y += 16;
2687 }
2688 if XNN_UNLIKELY(n != 0) {
2689 assert(n >= 1 * sizeof(float));
2690 assert(n <= 15 * sizeof(float));
2691 // Prepare mask for valid 32-bit elements (depends on n).
2692 n >>= 2 /* log2(sizeof(float)) */;
2693 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2694
2695 __m512 vacc = _mm512_maskz_loadu_ps(vmask, x);
2696 const __mmask16 vsign = _mm512_mask_cmp_ps_mask(vmask, vacc, vzero, _CMP_LT_OQ);
2697 vacc = _mm512_mask_mul_ps(vacc, vsign, vacc, vslope);
2698 _mm512_mask_storeu_ps(y, vmask, vacc);
2699 }
2700 }
2701
2702 void xnn_f32_vrndd_ukernel__avx512f_x16(
2703 size_t n,
2704 const float* x,
2705 float* y,
2706 const union xnn_f32_rnd_params params[restrict XNN_MIN_ELEMENTS(1)])
2707 {
2708 assert(n != 0);
2709 assert(n % sizeof(float) == 0);
2710
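// Round toward negative infinity via AVX-512 roundscale; the vrndne/vrndu/vrndz
// kernels below differ only in the rounding-mode immediate.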
2711 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2712 const __m512 vx0123456789ABCDEF = _mm512_loadu_ps(x);
2713 x += 16;
2714
2715 const __m512 vy0123456789ABCDEF = _mm512_roundscale_ps(vx0123456789ABCDEF, _MM_FROUND_TO_NEG_INF);
2716
2717 _mm512_storeu_ps(y, vy0123456789ABCDEF);
2718 y += 16;
2719 }
2720 if XNN_UNLIKELY(n != 0) {
2721 assert(n >= 1 * sizeof(float));
2722 assert(n <= 15 * sizeof(float));
2723 // Prepare mask for valid 32-bit elements (depends on n).
2724 n >>= 2 /* log2(sizeof(float)) */;
2725 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2726
2727 const __m512 vx = _mm512_maskz_loadu_ps(vmask, x);
2728 const __m512 vy = _mm512_maskz_roundscale_ps(vmask, vx, _MM_FROUND_TO_NEG_INF);
2729 _mm512_mask_storeu_ps(y, vmask, vy);
2730 }
2731 }
2732
2733 void xnn_f32_vrndne_ukernel__avx512f_x16(
2734 size_t n,
2735 const float* x,
2736 float* y,
2737 const union xnn_f32_rnd_params params[restrict XNN_MIN_ELEMENTS(1)])
2738 {
2739 assert(n != 0);
2740 assert(n % sizeof(float) == 0);
2741
2742 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2743 const __m512 vx0123456789ABCDEF = _mm512_loadu_ps(x);
2744 x += 16;
2745
2746 const __m512 vy0123456789ABCDEF = _mm512_roundscale_ps(vx0123456789ABCDEF, _MM_FROUND_TO_NEAREST_INT);
2747
2748 _mm512_storeu_ps(y, vy0123456789ABCDEF);
2749 y += 16;
2750 }
2751 if XNN_UNLIKELY(n != 0) {
2752 assert(n >= 1 * sizeof(float));
2753 assert(n <= 15 * sizeof(float));
2754 // Prepare mask for valid 32-bit elements (depends on n).
2755 n >>= 2 /* log2(sizeof(float)) */;
2756 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2757
2758 const __m512 vx = _mm512_maskz_loadu_ps(vmask, x);
2759 const __m512 vy = _mm512_maskz_roundscale_ps(vmask, vx, _MM_FROUND_TO_NEAREST_INT);
2760 _mm512_mask_storeu_ps(y, vmask, vy);
2761 }
2762 }
2763
2764 void xnn_f32_vrndu_ukernel__avx512f_x16(
2765 size_t n,
2766 const float* x,
2767 float* y,
2768 const union xnn_f32_rnd_params params[restrict XNN_MIN_ELEMENTS(1)])
2769 {
2770 assert(n != 0);
2771 assert(n % sizeof(float) == 0);
2772
2773 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2774 const __m512 vx0123456789ABCDEF = _mm512_loadu_ps(x);
2775 x += 16;
2776
2777 const __m512 vy0123456789ABCDEF = _mm512_roundscale_ps(vx0123456789ABCDEF, _MM_FROUND_TO_POS_INF);
2778
2779 _mm512_storeu_ps(y, vy0123456789ABCDEF);
2780 y += 16;
2781 }
2782 if XNN_UNLIKELY(n != 0) {
2783 assert(n >= 1 * sizeof(float));
2784 assert(n <= 15 * sizeof(float));
2785 // Prepare mask for valid 32-bit elements (depends on n).
2786 n >>= 2 /* log2(sizeof(float)) */;
2787 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2788
2789 const __m512 vx = _mm512_maskz_loadu_ps(vmask, x);
2790 const __m512 vy = _mm512_maskz_roundscale_ps(vmask, vx, _MM_FROUND_TO_POS_INF);
2791 _mm512_mask_storeu_ps(y, vmask, vy);
2792 }
2793 }
2794
2795 void xnn_f32_vrndz_ukernel__avx512f_x16(
2796 size_t n,
2797 const float* x,
2798 float* y,
2799 const union xnn_f32_rnd_params params[restrict XNN_MIN_ELEMENTS(1)])
2800 {
2801 assert(n != 0);
2802 assert(n % sizeof(float) == 0);
2803
2804 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
2805 const __m512 vx0123456789ABCDEF = _mm512_loadu_ps(x);
2806 x += 16;
2807
2808 const __m512 vy0123456789ABCDEF = _mm512_roundscale_ps(vx0123456789ABCDEF, _MM_FROUND_TO_ZERO);
2809
2810 _mm512_storeu_ps(y, vy0123456789ABCDEF);
2811 y += 16;
2812 }
2813 if XNN_UNLIKELY(n != 0) {
2814 assert(n >= 1 * sizeof(float));
2815 assert(n <= 15 * sizeof(float));
2816 // Prepare mask for valid 32-bit elements (depends on n).
2817 n >>= 2 /* log2(sizeof(float)) */;
2818 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2819
2820 const __m512 vx = _mm512_maskz_loadu_ps(vmask, x);
2821 const __m512 vy = _mm512_maskz_roundscale_ps(vmask, vx, _MM_FROUND_TO_ZERO);
2822 _mm512_mask_storeu_ps(y, vmask, vy);
2823 }
2824 }
2825
2826 void xnn_f32_vsigmoid_ukernel__avx512f_rr2_lut32_p2_perm2_scalef_div_x64(
2827 size_t n,
2828 const float* x,
2829 float* y,
2830 const union xnn_f32_sigmoid_params params[restrict XNN_MIN_ELEMENTS(1)])
2831 {
2832 assert(n % sizeof(float) == 0);
2833
2834 const __m512i vsign_mask = _mm512_set1_epi32((int) params->avx512_rr2_lut32_p2.sign_mask);
2835 const __m512 vmagic_bias = _mm512_set1_ps(params->avx512_rr2_lut32_p2.magic_bias);
2836 const __m512 vlog2e = _mm512_set1_ps(params->avx512_rr2_lut32_p2.log2e);
2837 const __m512 vtable_lo = _mm512_load_ps(params->avx512_rr2_lut32_p2.table_lo);
2838 const __m512 vtable_hi = _mm512_load_ps(params->avx512_rr2_lut32_p2.table_hi);
2839 const __m512 vminus_ln2_hi = _mm512_set1_ps(params->avx512_rr2_lut32_p2.minus_ln2_hi);
2840 const __m512 vminus_ln2_lo = _mm512_set1_ps(params->avx512_rr2_lut32_p2.minus_ln2_lo);
2841 const __m512 vc2 = _mm512_set1_ps(params->avx512_rr2_lut32_p2.c2);
2842 const __m512 vc1 = _mm512_set1_ps(params->avx512_rr2_lut32_p2.c1);
2843 const __m512 vone = _mm512_set1_ps(params->avx512_rr2_lut32_p2.one);
2844
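// sigmoid(x) = 1 / (1 + exp(-x)), evaluated on the non-positive half-range:
// z = -|x| is formed by OR-ing in the sign bit, e = exp(z) is computed with a
// 32-entry two-register lookup table (rr2_lut32_p2), a degree-2 polynomial, and
// _mm512_scalef_ps for the power-of-two scaling; f = e / (e + 1) is taken, and a
// masked subtract reflects the result to 1 - f for non-negative inputs.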
  for (; n >= 64 * sizeof(float); n -= 64 * sizeof(float)) {
    const __m512 vx0 = _mm512_loadu_ps(x);
    const __m512 vx1 = _mm512_loadu_ps(x + 16);
    const __m512 vx2 = _mm512_loadu_ps(x + 32);
    const __m512 vx3 = _mm512_loadu_ps(x + 48);
    x += 64;

    // Force the sign bit: z = -|x|, so exp(z) lies in (0, 1] and cannot overflow.
    const __m512 vz0 = _mm512_castsi512_ps(_mm512_or_epi32(_mm512_castps_si512(vx0), vsign_mask));
    const __m512 vz1 = _mm512_castsi512_ps(_mm512_or_epi32(_mm512_castps_si512(vx1), vsign_mask));
    const __m512 vz2 = _mm512_castsi512_ps(_mm512_or_epi32(_mm512_castps_si512(vx2), vsign_mask));
    const __m512 vz3 = _mm512_castsi512_ps(_mm512_or_epi32(_mm512_castps_si512(vx3), vsign_mask));

    // Magic-bias trick: n := z * log2(e) rounded so that the low bits of its
    // binary representation hold the index into the 32-entry table.
    __m512 vn0 = _mm512_fmadd_ps(vz0, vlog2e, vmagic_bias);
    __m512 vn1 = _mm512_fmadd_ps(vz1, vlog2e, vmagic_bias);
    __m512 vn2 = _mm512_fmadd_ps(vz2, vlog2e, vmagic_bias);
    __m512 vn3 = _mm512_fmadd_ps(vz3, vlog2e, vmagic_bias);

    // Table lookup of 2**(fractional part) across the two halves of the table.
    const __m512 vl0 = _mm512_permutex2var_ps(vtable_lo, _mm512_castps_si512(vn0), vtable_hi);
    const __m512 vl1 = _mm512_permutex2var_ps(vtable_lo, _mm512_castps_si512(vn1), vtable_hi);
    const __m512 vl2 = _mm512_permutex2var_ps(vtable_lo, _mm512_castps_si512(vn2), vtable_hi);
    const __m512 vl3 = _mm512_permutex2var_ps(vtable_lo, _mm512_castps_si512(vn3), vtable_hi);

    // Remove the magic bias to recover n as a regular float.
    vn0 = _mm512_sub_ps(vn0, vmagic_bias);
    vn1 = _mm512_sub_ps(vn1, vmagic_bias);
    vn2 = _mm512_sub_ps(vn2, vmagic_bias);
    vn3 = _mm512_sub_ps(vn3, vmagic_bias);

    // Two-step range reduction t = z - n * ln(2), with ln(2) split into
    // high and low parts for extra accuracy ("rr2").
    __m512 vt0 = _mm512_fmadd_ps(vn0, vminus_ln2_hi, vz0);
    __m512 vt1 = _mm512_fmadd_ps(vn1, vminus_ln2_hi, vz1);
    __m512 vt2 = _mm512_fmadd_ps(vn2, vminus_ln2_hi, vz2);
    __m512 vt3 = _mm512_fmadd_ps(vn3, vminus_ln2_hi, vz3);

    vt0 = _mm512_fmadd_ps(vn0, vminus_ln2_lo, vt0);
    vt1 = _mm512_fmadd_ps(vn1, vminus_ln2_lo, vt1);
    vt2 = _mm512_fmadd_ps(vn2, vminus_ln2_lo, vt2);
    vt3 = _mm512_fmadd_ps(vn3, vminus_ln2_lo, vt3);

    // Degree-2 polynomial for exp(t), pre-scaled by the table value l:
    // p = l * (1 + c1*t + c2*t^2).
    __m512 vp0 = _mm512_fmadd_ps(vt0, vc2, vc1);
    __m512 vp1 = _mm512_fmadd_ps(vt1, vc2, vc1);
    __m512 vp2 = _mm512_fmadd_ps(vt2, vc2, vc1);
    __m512 vp3 = _mm512_fmadd_ps(vt3, vc2, vc1);

    vt0 = _mm512_mul_ps(vt0, vl0);
    vt1 = _mm512_mul_ps(vt1, vl1);
    vt2 = _mm512_mul_ps(vt2, vl2);
    vt3 = _mm512_mul_ps(vt3, vl3);

    vp0 = _mm512_fmadd_ps(vt0, vp0, vl0);
    vp1 = _mm512_fmadd_ps(vt1, vp1, vl1);
    vp2 = _mm512_fmadd_ps(vt2, vp2, vl2);
    vp3 = _mm512_fmadd_ps(vt3, vp3, vl3);

    // Reconstruct e = exp(z) by scaling with 2**floor(n) (VSCALEFPS).
    const __m512 ve0 = _mm512_scalef_ps(vp0, vn0);
    const __m512 ve1 = _mm512_scalef_ps(vp1, vn1);
    const __m512 ve2 = _mm512_scalef_ps(vp2, vn2);
    const __m512 ve3 = _mm512_scalef_ps(vp3, vn3);

    // sigmoid(z) = e / (e + 1) for z <= 0.
    const __m512 vd0 = _mm512_add_ps(ve0, vone);
    const __m512 vd1 = _mm512_add_ps(ve1, vone);
    const __m512 vd2 = _mm512_add_ps(ve2, vone);
    const __m512 vd3 = _mm512_add_ps(ve3, vone);

    __m512 vf0 = _mm512_div_ps(ve0, vd0);
    __m512 vf1 = _mm512_div_ps(ve1, vd1);
    __m512 vf2 = _mm512_div_ps(ve2, vd2);
    __m512 vf3 = _mm512_div_ps(ve3, vd3);

    // For non-negative x (sign bit clear), mirror: sigmoid(x) = 1 - sigmoid(-x).
    vf0 = _mm512_mask_sub_ps(vf0, _mm512_testn_epi32_mask(_mm512_castps_si512(vx0), vsign_mask), vone, vf0);
    vf1 = _mm512_mask_sub_ps(vf1, _mm512_testn_epi32_mask(_mm512_castps_si512(vx1), vsign_mask), vone, vf1);
    vf2 = _mm512_mask_sub_ps(vf2, _mm512_testn_epi32_mask(_mm512_castps_si512(vx2), vsign_mask), vone, vf2);
    vf3 = _mm512_mask_sub_ps(vf3, _mm512_testn_epi32_mask(_mm512_castps_si512(vx3), vsign_mask), vone, vf3);

    _mm512_storeu_ps(y, vf0);
    _mm512_storeu_ps(y + 16, vf1);
    _mm512_storeu_ps(y + 32, vf2);
    _mm512_storeu_ps(y + 48, vf3);
    y += 64;
  }
  for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
    const __m512 vx = _mm512_loadu_ps(x);
    x += 16;

    const __m512 vz = _mm512_castsi512_ps(_mm512_or_epi32(_mm512_castps_si512(vx), vsign_mask));

    __m512 vn = _mm512_fmadd_ps(vz, vlog2e, vmagic_bias);
    const __m512 vl = _mm512_permutex2var_ps(vtable_lo, _mm512_castps_si512(vn), vtable_hi);
    vn = _mm512_sub_ps(vn, vmagic_bias);

    __m512 vt = _mm512_fmadd_ps(vn, vminus_ln2_hi, vz);
    vt = _mm512_fmadd_ps(vn, vminus_ln2_lo, vt);

    __m512 vp = _mm512_fmadd_ps(vt, vc2, vc1);
    vt = _mm512_mul_ps(vt, vl);
    vp = _mm512_fmadd_ps(vt, vp, vl);

    const __m512 ve = _mm512_scalef_ps(vp, vn);
    const __m512 vd = _mm512_add_ps(ve, vone);

    __m512 vf = _mm512_div_ps(ve, vd);

    vf = _mm512_mask_sub_ps(vf, _mm512_testn_epi32_mask(_mm512_castps_si512(vx), vsign_mask), vone, vf);

    _mm512_storeu_ps(y, vf);
    y += 16;
  }
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(float));
    assert(n <= 15 * sizeof(float));

    // Prepare mask for valid 32-bit elements (depends on n).
    n >>= 2 /* log2(sizeof(float)) */;
    const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));

    const __m512 vx = _mm512_maskz_loadu_ps(vmask, x);
    const __m512 vz = _mm512_castsi512_ps(_mm512_or_epi32(_mm512_castps_si512(vx), vsign_mask));

    __m512 vn = _mm512_fmadd_ps(vz, vlog2e, vmagic_bias);
    const __m512 vl = _mm512_permutex2var_ps(vtable_lo, _mm512_castps_si512(vn), vtable_hi);
    vn = _mm512_sub_ps(vn, vmagic_bias);

    __m512 vt = _mm512_fmadd_ps(vn, vminus_ln2_hi, vz);
    vt = _mm512_fmadd_ps(vn, vminus_ln2_lo, vt);

    __m512 vp = _mm512_fmadd_ps(vt, vc2, vc1);
    vt = _mm512_mul_ps(vt, vl);
    vp = _mm512_fmadd_ps(vt, vp, vl);

    const __m512 ve = _mm512_scalef_ps(vp, vn);
    const __m512 vd = _mm512_add_ps(ve, vone);

    __m512 vf = _mm512_div_ps(ve, vd);

    vf = _mm512_mask_sub_ps(vf, _mm512_testn_epi32_mask(_mm512_castps_si512(vx), vsign_mask), vone, vf);

    _mm512_mask_storeu_ps(y, vmask, vf);
  }
}
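
// A rough scalar reference of the evaluation strategy used above (illustrative
// only; the vector kernel computes the exponential with a table plus a short
// polynomial rather than calling libm). The structure is the same: evaluate on
// z = -|x| so that exp(z) never overflows, then recover the result for positive
// inputs from the identity sigmoid(x) = 1 - sigmoid(-x).
#include <math.h>  // expf, fabsf -- needed only for this sketch; harmless if already included

static inline float example_sigmoid_reference(float x) {
  const float z = -fabsf(x);  // range-reduce to non-positive inputs
  const float e = expf(z);    // exp(z) is in (0, 1]
  float f = e / (e + 1.0f);   // sigmoid(z) for z <= 0
  if (x > 0.0f) {
    f = 1.0f - f;             // mirror the result for positive inputs
  }
  return f;
}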

void xnn_f32_vabs_ukernel__avx512f_x16(
    size_t n,
    const float* x,
    float* y,
    const union xnn_f32_abs_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);
  assert(x != NULL);
  assert(y != NULL);

  const __m512i vnonsign_mask = _mm512_set1_epi32((int) params->avx512.nonsign_mask);
  for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
    const __m512i vx0123456789ABCDEF = _mm512_loadu_si512(x);
    x += 16;

    const __m512i vy0123456789ABCDEF = _mm512_and_epi32(vx0123456789ABCDEF, vnonsign_mask);

    _mm512_storeu_si512(y, vy0123456789ABCDEF);
    y += 16;
  }
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(float));
    assert(n <= 15 * sizeof(float));
    // Prepare mask for valid 32-bit elements (depends on n).
    n >>= 2 /* log2(sizeof(float)) */;
    const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));

    const __m512i vx = _mm512_maskz_loadu_epi32(vmask, x);
    const __m512i vy = _mm512_and_epi32(vx, vnonsign_mask);
    _mm512_mask_storeu_epi32(y, vmask, vy);
  }
}
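
// Scalar sketch of the same bit trick (illustrative, not part of the generated
// kernel above): |x| is produced by clearing the IEEE-754 sign bit, which is
// exactly what the vector code does with vnonsign_mask.
static inline float example_abs_by_mask(float x) {
  union {
    float f;
    uint32_t u;
  } bits = { .f = x };
  bits.u &= UINT32_C(0x7FFFFFFF);  // clear bit 31, the sign bit
  return bits.f;
}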

void xnn_f32_vneg_ukernel__avx512f_x16(
    size_t n,
    const float* x,
    float* y,
    const union xnn_f32_neg_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);
  assert(x != NULL);
  assert(y != NULL);

  const __m512i vsign_mask = _mm512_set1_epi32((int) params->avx512.sign_mask);
  for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
    const __m512i vx0123456789ABCDEF = _mm512_loadu_si512(x);
    x += 16;

    const __m512i vy0123456789ABCDEF = _mm512_xor_epi32(vx0123456789ABCDEF, vsign_mask);

    _mm512_storeu_si512(y, vy0123456789ABCDEF);
    y += 16;
  }
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(float));
    assert(n <= 15 * sizeof(float));
    // Prepare mask for valid 32-bit elements (depends on n).
    n >>= 2 /* log2(sizeof(float)) */;
    const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));

    const __m512i vx = _mm512_maskz_loadu_epi32(vmask, x);
    const __m512i vy = _mm512_xor_epi32(vx, vsign_mask);
    _mm512_mask_storeu_epi32(y, vmask, vy);
  }
}
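
// Scalar sketch of the corresponding negation trick (illustrative only):
// flipping the IEEE-754 sign bit negates the value, matching the XOR with
// vsign_mask in the vector code above.
static inline float example_neg_by_xor(float x) {
  union {
    float f;
    uint32_t u;
  } bits = { .f = x };
  bits.u ^= UINT32_C(0x80000000);  // flip bit 31, the sign bit
  return bits.f;
}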

void xnn_f32_vsqr_ukernel__avx512f_x16(
    size_t n,
    const float* x,
    float* y,
    const union xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);
  assert(x != NULL);
  assert(y != NULL);

  for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
    const __m512 vx0123456789ABCDEF = _mm512_loadu_ps(x);
    x += 16;

    const __m512 vy0123456789ABCDEF = _mm512_mul_ps(vx0123456789ABCDEF, vx0123456789ABCDEF);

    _mm512_storeu_ps(y, vy0123456789ABCDEF);
    y += 16;
  }
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(float));
    assert(n <= 15 * sizeof(float));
    // Prepare mask for valid 32-bit elements (depends on n).
    n >>= 2 /* log2(sizeof(float)) */;
    const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));

    const __m512 vx = _mm512_maskz_loadu_ps(vmask, x);
    const __m512 vy = _mm512_mul_ps(vx, vx);
    _mm512_mask_storeu_ps(y, vmask, vy);
  }
}
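
// Minimal usage sketch (assumption: the caller owns contiguous input and output
// buffers; the buffer size of 21 floats below is hypothetical and chosen only to
// exercise the masked remainder path). Note that `n` is a byte count, so 21
// floats are passed as 21 * sizeof(float). The kernel body above never reads
// `params`, so an instance is passed only to satisfy the prototype.
static inline void example_call_vsqr(const float src[21], float dst[21]) {
  union xnn_f32_default_params params;  // unused by this kernel
  xnn_f32_vsqr_ukernel__avx512f_x16(21 * sizeof(float), src, dst, &params);
}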