/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once
#include <qnnpack/math.h>
#include <stdint.h>

// Legend:
// dq: Design-time Quantization
// rq: Run-time Quantization

static inline void pytorch_pack_q8gemm_wdq(
    size_t nc,
    size_t kc,
    uint32_t nr,
    uint32_t np,
    uint32_t kr,
    uint8_t izp,
    uint8_t kzp,
    const uint8_t* k,
    const int32_t* b,
    void* packed_w) {
  const int32_t boff = (int32_t)kc * (int32_t)izp * (int32_t)kzp;
  for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
    const size_t nr_block_size = min(nc - nr_block_start, nr);
    int32_t* packed_b = (int32_t*)packed_w;
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *((int32_t*)packed_w) =
          b ? b[nr_block_start + nr_block_offset] + boff : 0;
      packed_w = (void*)((uintptr_t)packed_w + sizeof(int32_t));
    }
    packed_w =
        (void*)((uintptr_t)packed_w + (nr - nr_block_size) * sizeof(int32_t));
    for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
      const size_t kr_block_size = min(kc - kr_block_start, kr);
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
           nr_block_offset++) {
        int32_t ksum = 0;
        for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
             kr_block_offset++) {
          const uint8_t kv =
              k[(nr_block_start + nr_block_offset) * kc +
                (kr_block_start + kr_block_offset)];
          ksum += (int32_t)kv;
          *((uint8_t*)packed_w) = kv;
          packed_w = (void*)((uintptr_t)packed_w + sizeof(uint8_t));
        }
        packed_b[nr_block_offset] -= ksum * (int32_t)izp;
        packed_w =
            (void*)((uintptr_t)packed_w +
                    (kr - kr_block_size) * sizeof(uint8_t));
      }
      packed_w =
          (void*)((uintptr_t)packed_w +
                  ((nr - nr_block_size) & (np - 1)) * kr * sizeof(uint8_t));
    }
  }
}
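
/*
 * A minimal sizing sketch (an illustration, not part of the library API):
 * with the layout produced above, each block of nr output channels occupies
 * nr int32 bias slots followed by round-up(kc, kr) * nr weight bytes, so a
 * conservative packed-buffer size can be computed as below. This assumes
 * nr == np and that nr and kr are powers of two, as in the built-in
 * ukernels.
 */
static inline size_t pytorch_q8gemm_packed_size_example(
    size_t nc,
    size_t kc,
    uint32_t nr,
    uint32_t kr) {
  const size_t n_stride = (nc + (nr - 1)) & -(size_t)nr; /* round up to nr */
  const size_t k_stride = (kc + (kr - 1)) & -(size_t)kr; /* round up to kr */
  return n_stride * (sizeof(int32_t) + k_stride * sizeof(uint8_t));
}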

// NB: We use the same packing function for both dynamic quantization
// and runtime quantization for linear.
// This means the dynamic-quantization path pays a small performance cost
// for the branching introduced by the `if (kzp != 0)` checks, but that cost
// should not be significant.
static inline void pytorch_pack_q8gemm_wrq(
    const size_t nc,
    const size_t kc,
    const uint32_t nr,
    const uint32_t np,
    const uint32_t kr,
    const uint8_t* const k,
    const int32_t* const b,
    const uint8_t* const kzp,
    void* const packed_w) {
  union {
    void* const as_void_ptr;
    uint8_t* as_uint8_ptr;
    int32_t* as_int32_ptr;
  } packed = {packed_w};

  for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
    const size_t nr_block_size = min(nc - nr_block_start, nr);
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *(packed.as_int32_ptr++) = b ? b[nr_block_start + nr_block_offset] : 0;
    }
    packed.as_int32_ptr += (nr - nr_block_size);
    for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
      const size_t kr_block_size = min(kc - kr_block_start, kr);
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
           nr_block_offset++) {
        for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
             kr_block_offset++) {
          const uint8_t kv =
              k[(nr_block_start + nr_block_offset) * kc +
                (kr_block_start + kr_block_offset)];
          *(packed.as_uint8_ptr++) = kv;
        }
        // Weights need to be prepacked with the zero points, in the tail
        // space where a packed block is not a multiple of the input size.
        // E.g., for ukernels with kr = 2 and kc = 3, the second block must
        // be padded with the zero point: when the ukernel subtracts the
        // zero point, the padded values contribute exactly zero, which is
        // what we want.
        if (kzp != 0) {
          for (size_t kr_block_offset = 0;
               kr_block_offset < (kr - kr_block_size);
               kr_block_offset++) {
            const uint8_t kv = kzp[(nr_block_start + nr_block_offset)];
            *(packed.as_uint8_ptr++) = kv;
          }
        } else {
          packed.as_uint8_ptr += (kr - kr_block_size);
        }
      }
      if (kzp != 0) {
        // This part fills the packed weights with zero points for output
        // channels when the channel count is not divisible by the nr
        // blocking parameter. This is needed because some kernels (the
        // SSE2 ones, for example) rely on it to produce zero when the zero
        // point is subtracted from the weight value.
        size_t remaining_nr_blocks = ((nr - nr_block_size) & (np - 1));
        for (size_t nr_block_offset = 0;
             nr_block_offset < remaining_nr_blocks;
             nr_block_offset++) {
          for (size_t kr_block_offset = 0; kr_block_offset < kr;
               kr_block_offset++) {
            const uint8_t kv =
                kzp[(nr_block_start + nr_block_size + nr_block_offset)];
            *(packed.as_uint8_ptr++) = kv;
          }
        }
      } else {
        packed.as_uint8_ptr += ((nr - nr_block_size) & (np - 1)) * kr;
      }
    }
  }
}
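
/*
 * A minimal usage sketch (hypothetical dimensions, for illustration only):
 * pack a 6x4 uint8 weight matrix for a ukernel with nr = np = 4 and kr = 2.
 * The buffer holds ceil(6/4) = 2 row blocks, each with 4 int32 bias slots
 * followed by round-up(4, 2) * 4 = 16 weight bytes. Note that the
 * per-channel zero-point array is sized to round-up(nc, nr), because the
 * tail fill above reads zero points past nc.
 */
static inline void pytorch_pack_q8gemm_wrq_usage_example(void) {
  const uint8_t k[6 * 4] = {0}; /* weights, row-major [nc][kc] */
  const int32_t b[6] = {0}; /* per-output-channel bias */
  const uint8_t kzp[8] = {0}; /* kernel zero points, padded to round-up(nc, nr) */
  uint8_t packed[2 * (4 * sizeof(int32_t) + 4 * 4)]; /* 2 blocks * 32 bytes */
  pytorch_pack_q8gemm_wrq(6, 4, 4, 4, 2, k, b, kzp, packed);
}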

static inline void pytorch_pack_q8conv_wdq(
    size_t n,
    size_t ks,
    size_t kc,
    uint32_t nr,
    uint32_t kr,
    uint8_t izp,
    uint8_t kzp,
    const uint8_t* k,
    const int32_t* b,
    void* packed_w) {
  const int32_t boff = (int32_t)ks * (int32_t)kc * (int32_t)izp * (int32_t)kzp;
  for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) {
    const size_t nr_block_size = min(n - nr_block_start, nr);
    int32_t* packed_b = (int32_t*)packed_w;
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *((int32_t*)packed_w) =
          b ? b[nr_block_start + nr_block_offset] + boff : 0;
      packed_w = (void*)((uintptr_t)packed_w + sizeof(int32_t));
    }
    packed_w =
        (void*)((uintptr_t)packed_w + (nr - nr_block_size) * sizeof(int32_t));
    for (size_t ki = 0; ki < ks; ki++) {
      for (size_t kr_block_start = 0; kr_block_start < kc;
           kr_block_start += kr) {
        const size_t kr_block_size = min(kc - kr_block_start, kr);
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
             nr_block_offset++) {
          int32_t ksum = 0;
          for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
               kr_block_offset++) {
            const uint8_t kv =
                k[((nr_block_start + nr_block_offset) * ks + ki) * kc +
                  (kr_block_start + kr_block_offset)];
            ksum += (int32_t)kv;
            *((uint8_t*)packed_w) = kv;
            packed_w = (void*)((uintptr_t)packed_w + sizeof(uint8_t));
          }
          packed_b[nr_block_offset] -= ksum * (int32_t)izp;
          packed_w =
              (void*)((uintptr_t)packed_w +
                      (kr - kr_block_size) * sizeof(uint8_t));
        }
        packed_w =
            (void*)((uintptr_t)packed_w +
                    (nr - nr_block_size) * kr * sizeof(uint8_t));
      }
    }
  }
}

static inline void pytorch_pack_q8conv_wrq(
    const size_t n,
    const size_t ks,
    const size_t kc,
    const uint32_t nr,
    const uint32_t kr,
    const uint8_t* const k,
    const int32_t* const b,
    const uint8_t* const kzp,
    void* const packed_w) {
  union {
    void* const as_void_ptr;
    uint8_t* as_uint8_ptr;
    int32_t* as_int32_ptr;
  } packed = {packed_w};

  for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) {
    const size_t nr_block_size = min(n - nr_block_start, nr);
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *(packed.as_int32_ptr++) = b ? b[nr_block_start + nr_block_offset] : 0;
    }
    packed.as_int32_ptr += (nr - nr_block_size);
    for (size_t ki = 0; ki < ks; ki++) {
      for (size_t kr_block_start = 0; kr_block_start < kc;
           kr_block_start += kr) {
        const size_t kr_block_size = min(kc - kr_block_start, kr);
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
             nr_block_offset++) {
          for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
               kr_block_offset++) {
            const uint8_t kv =
                k[((nr_block_start + nr_block_offset) * ks + ki) * kc +
                  (kr_block_start + kr_block_offset)];
            *(packed.as_uint8_ptr++) = kv;
          }
          // Weights need to be prepacked with the zero points, in the tail
          // space where a packed block is not a multiple of the input size.
          // E.g., for ukernels with kr = 2 and kc = 3, the second block must
          // be padded with the zero point: when the ukernel subtracts the
          // zero point, the padded values contribute exactly zero, which is
          // what we want.
          if (kzp != 0) {
            for (size_t kr_block_offset = 0;
                 kr_block_offset < (kr - kr_block_size);
                 kr_block_offset++) {
              const uint8_t kv = kzp[(nr_block_start + nr_block_offset)];
              *(packed.as_uint8_ptr++) = kv;
            }
          } else {
            packed.as_uint8_ptr += (kr - kr_block_size);
          }
        }
        if (kzp != 0) {
          // This part fills the packed weights with zero points for output
          // channels when the channel count is not divisible by the nr
          // blocking parameter, so that subtracting the zero point in the
          // ukernel yields zero for those lanes.
          for (size_t nr_block_offset = 0;
               nr_block_offset < (nr - nr_block_size);
               nr_block_offset++) {
            for (size_t kr_block_offset = 0; kr_block_offset < kr;
                 kr_block_offset++) {
              const uint8_t kv =
                  kzp[(nr_block_start + nr_block_size + nr_block_offset)];
              *(packed.as_uint8_ptr++) = kv;
            }
          }
        } else {
          packed.as_uint8_ptr += (nr - nr_block_size) * kr;
        }
      }
    }
  }
}
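
/*
 * A small index helper (illustrative only, not part of the library API):
 * the convolution packers above read the source weights as k[n][ks][kc],
 * i.e. output channel, then kernel spatial position, then input-channel
 * slice. This mirrors the indexing expression used in the loops above.
 */
static inline size_t pytorch_q8conv_k_index_example(
    size_t n_idx, size_t ki, size_t kc_idx, size_t ks, size_t kc) {
  return (n_idx * ks + ki) * kc + kc_idx;
}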

static inline void pytorch_pack_q8deconv_wdq(
    size_t n,
    size_t ks,
    size_t kc,
    uint32_t nr,
    uint32_t kr,
    uint8_t izp,
    uint8_t kzp,
    const uint8_t* k,
    const int32_t* b,
    void* packed_w) {
  const int32_t boff = (int32_t)ks * (int32_t)kc * (int32_t)izp * (int32_t)kzp;
  for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) {
    const size_t nr_block_size = min(n - nr_block_start, nr);
    int32_t* packed_b = (int32_t*)packed_w;
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *((int32_t*)packed_w) =
          b ? b[nr_block_start + nr_block_offset] + boff : 0;
      packed_w = (void*)((uintptr_t)packed_w + sizeof(int32_t));
    }
    packed_w =
        (void*)((uintptr_t)packed_w + (nr - nr_block_size) * sizeof(int32_t));
    for (size_t ki = 0; ki < ks; ki++) {
      for (size_t kr_block_start = 0; kr_block_start < kc;
           kr_block_start += kr) {
        const size_t kr_block_size = min(kc - kr_block_start, kr);
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
             nr_block_offset++) {
          int32_t ksum = 0;
          for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
               kr_block_offset++) {
            const uint8_t kv =
                k[((kr_block_start + kr_block_offset) * ks + ki) * n +
                  (nr_block_start + nr_block_offset)];
            ksum += (int32_t)kv;
            *((uint8_t*)packed_w) = kv;
            packed_w = (void*)((uintptr_t)packed_w + sizeof(uint8_t));
          }
          packed_b[nr_block_offset] -= ksum * (int32_t)izp;
          packed_w =
              (void*)((uintptr_t)packed_w +
                      (kr - kr_block_size) * sizeof(uint8_t));
        }
        packed_w =
            (void*)((uintptr_t)packed_w +
                    (nr - nr_block_size) * kr * sizeof(uint8_t));
      }
    }
  }
}

static inline void pytorch_pack_q8deconv_wrq(
    const size_t n,
    const size_t ks,
    const size_t kc,
    const uint32_t nr,
    const uint32_t kr,
    const uint8_t* const k,
    const int32_t* const b,
    const uint8_t* const kzp,
    void* const packed_w) {
  union {
    void* const as_void_ptr;
    uint8_t* as_uint8_ptr;
    int32_t* as_int32_ptr;
  } packed = {packed_w};

  for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) {
    const size_t nr_block_size = min(n - nr_block_start, nr);
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *(packed.as_int32_ptr++) = b ? b[nr_block_start + nr_block_offset] : 0;
    }
    packed.as_int32_ptr += (nr - nr_block_size);
    for (size_t ki = 0; ki < ks; ki++) {
      for (size_t kr_block_start = 0; kr_block_start < kc;
           kr_block_start += kr) {
        const size_t kr_block_size = min(kc - kr_block_start, kr);
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
             nr_block_offset++) {
          for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
               kr_block_offset++) {
            const uint8_t kv =
                k[((kr_block_start + kr_block_offset) * ks + ki) * n +
                  (nr_block_start + nr_block_offset)];
            *(packed.as_uint8_ptr++) = kv;
          }
          // Weights need to be prepacked with the zero points, in the tail
          // space where a packed block is not a multiple of the input size.
          // E.g., for ukernels with kr = 2 and kc = 3, the second block must
          // be padded with the zero point: when the ukernel subtracts the
          // zero point, the padded values contribute exactly zero, which is
          // what we want.
          if (kzp != 0) {
            for (size_t kr_block_offset = 0;
                 kr_block_offset < (kr - kr_block_size);
                 kr_block_offset++) {
              const uint8_t kv = kzp[(nr_block_start + nr_block_offset)];
              *(packed.as_uint8_ptr++) = kv;
            }
          } else {
            packed.as_uint8_ptr += (kr - kr_block_size);
          }
        }
        if (kzp != 0) {
          // This part fills the packed weights with zero points for output
          // channels when the channel count is not divisible by the nr
          // blocking parameter, so that subtracting the zero point in the
          // ukernel yields zero for those lanes.
          for (size_t nr_block_offset = 0;
               nr_block_offset < (nr - nr_block_size);
               nr_block_offset++) {
            for (size_t kr_block_offset = 0; kr_block_offset < kr;
                 kr_block_offset++) {
              const uint8_t kv =
                  kzp[(nr_block_start + nr_block_size + nr_block_offset)];
              *(packed.as_uint8_ptr++) = kv;
            }
          }
        } else {
          packed.as_uint8_ptr += (nr - nr_block_size) * kr;
        }
      }
    }
  }
}
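
/*
 * The deconvolution packers differ from the convolution packers only in
 * the source layout: the weights are read transposed, as k[kc][ks][n].
 * A small index helper (illustrative only, not part of the library API)
 * mirroring the expression used in the loops above:
 */
static inline size_t pytorch_q8deconv_k_index_example(
    size_t n_idx, size_t ki, size_t kc_idx, size_t ks, size_t n) {
  return (kc_idx * ks + ki) * n + n_idx;
}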

static inline void pytorch_pack_q8dw_wdq(
    size_t h,
    size_t w,
    size_t c,
    size_t cr,
    uint8_t izp,
    uint8_t* kzp,
    const uint8_t* k,
    const int32_t* b,
    void* packed_w) {
  const int32_t boff = (int32_t)h * (int32_t)w * (int32_t)izp;
  for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
    const size_t cr_block_size = min(c - cr_block_start, cr);
    int32_t* packed_b = (int32_t*)packed_w;
    for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
         cr_block_offset++) {
      *((int32_t*)packed_w) = b
          ? b[cr_block_start + cr_block_offset] +
              boff * kzp[cr_block_start + cr_block_offset]
          : 0;
      packed_w = (void*)((uintptr_t)packed_w + sizeof(int32_t));
    }
    packed_w =
        (void*)((uintptr_t)packed_w + (cr - cr_block_size) * sizeof(int32_t));
    for (size_t x = 0; x < w; x++) {
      for (size_t y = 0; y < h; y++) {
        for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
             cr_block_offset++) {
          const uint8_t kv =
              k[((cr_block_start + cr_block_offset) * h + y) * w + x];
          packed_b[cr_block_offset] -= (int32_t)kv * (int32_t)izp;
          *((uint8_t*)packed_w) = kv;
          packed_w = (void*)((uintptr_t)packed_w + sizeof(uint8_t));
        }
        packed_w =
            (void*)((uintptr_t)packed_w +
                    (cr - cr_block_size) * sizeof(uint8_t));
      }
    }
  }
}
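
/*
 * A minimal sketch (illustrative, not part of the library API) of the bias
 * precomputation performed by pytorch_pack_q8dw_wdq for a single channel c:
 * the packed bias ends up as
 *   b[c] + izp * (h * w * kzp[c] - sum of the channel's kernel values),
 * which lets the ukernel skip the input-zero-point correction at run time.
 */
static inline int32_t pytorch_q8dw_packed_bias_example(
    size_t h,
    size_t w,
    size_t c_idx,
    uint8_t izp,
    const uint8_t* kzp,
    const uint8_t* k, /* laid out as k[c][h][w] */
    const int32_t* b) {
  int32_t ksum = 0;
  for (size_t i = 0; i < h * w; i++) {
    ksum += (int32_t)k[c_idx * h * w + i];
  }
  return b[c_idx] +
      (int32_t)h * (int32_t)w * (int32_t)izp * (int32_t)kzp[c_idx] -
      ksum * (int32_t)izp;
}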

static inline void pytorch_pack_q8dw_wrq(
    const size_t h,
    const size_t w,
    const size_t c,
    const size_t cr,
    const uint8_t* const k,
    const int32_t* const b,
    void* const packed_w) {
  union {
    void* const as_void_ptr;
    uint8_t* as_uint8_ptr;
    int32_t* as_int32_ptr;
  } packed = {packed_w};

  for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
    const size_t cr_block_size = min(c - cr_block_start, cr);
    for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
         cr_block_offset++) {
      *(packed.as_int32_ptr++) = b ? b[cr_block_start + cr_block_offset] : 0;
    }
    packed.as_int32_ptr += (cr - cr_block_size);
    for (size_t x = 0; x < w; x++) {
      for (size_t y = 0; y < h; y++) {
        for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
             cr_block_offset++) {
          const uint8_t kv =
              k[((cr_block_start + cr_block_offset) * h + y) * w + x];
          *(packed.as_uint8_ptr++) = kv;
        }
        packed.as_uint8_ptr += (cr - cr_block_size);
      }
    }
  }
}

static inline void pytorch_pack_q8dw_3d_w_dilation(
    size_t d,
    size_t h,
    size_t w,
    size_t c,
    size_t cr,
    size_t z_start,
    size_t z_end,
    size_t y_start,
    size_t y_end,
    size_t x_start,
    size_t x_end,
    const uint8_t* k,
    const int32_t* b,
    void* packed_w,
    bool pytorch_pack_b) {
  for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
    const size_t cr_block_size = min(c - cr_block_start, cr);
    if (pytorch_pack_b) {
      for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
           cr_block_offset++) {
        *((int32_t*)packed_w) = b ? b[cr_block_start + cr_block_offset] : 0;
        packed_w = (void*)((int32_t*)packed_w + 1);
      }
      packed_w = (void*)((int32_t*)packed_w + (cr - cr_block_size));
    }
    for (size_t x = x_start; x < x_end; x++) {
      for (size_t y = y_start; y < y_end; y++) {
        for (size_t z = z_start; z < z_end; z++) {
          for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
               cr_block_offset++) {
            *((uint8_t*)packed_w) =
                k[(((cr_block_start + cr_block_offset) * d + z) * h + y) * w +
                  x];
            packed_w = (void*)((uint8_t*)packed_w + 1);
          }
          packed_w = (void*)((uint8_t*)packed_w + (cr - cr_block_size));
        }
      }
    }
  }
}

static inline void pytorch_pack_q8dw_2d_w_dilation(
    size_t h,
    size_t w,
    size_t c,
    size_t cr,
    size_t y_start,
    size_t y_end,
    size_t x_start,
    size_t x_end,
    const uint8_t* k,
    const int32_t* b,
    void* packed_w,
    bool pytorch_pack_b) {
  pytorch_pack_q8dw_3d_w_dilation(
      1, /* d */
      h,
      w,
      c,
      cr,
      0, /* z_start */
      1, /* z_end */
      y_start,
      y_end,
      x_start,
      x_end,
      k,
      b,
      packed_w,
      pytorch_pack_b);
}
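
/*
 * A minimal usage sketch (hypothetical dimensions, for illustration only):
 * pack a full 3x3 depthwise kernel for 8 channels with cr = 4, bias
 * included. Passing sub-ranges of [y_start, y_end) and [x_start, x_end)
 * instead packs only a window of the kernel, which allows a dilated kernel
 * to be packed piecewise.
 */
static inline void pytorch_pack_q8dw_2d_w_dilation_usage_example(void) {
  const uint8_t k[8 * 3 * 3] = {0}; /* weights, laid out as k[c][h][w] */
  const int32_t b[8] = {0}; /* per-channel bias */
  /* 2 channel blocks * (4 int32 bias slots + 3 * 3 * 4 weight bytes) */
  uint8_t packed[2 * (4 * sizeof(int32_t) + 3 * 3 * 4)];
  pytorch_pack_q8dw_2d_w_dilation(
      3, 3, 8, 4, 0, 3, 0, 3, k, b, packed, true);
}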

static inline void pytorch_pack_swizzle_q8gemm_bdq(
    size_t n,
    size_t kc,
    uint32_t nr,
    uint32_t kr,
    uint32_t sr,
    uint8_t izp,
    uint8_t kzp,
    const uint8_t* k,
    const int32_t* b,
    void* packed_w) {
  const int32_t boff = (int32_t)kc * (int32_t)izp * (int32_t)kzp;
  for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) {
    const size_t nr_block_size = min(n - nr_block_start, nr);
    int32_t* packed_b = (int32_t*)packed_w;
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *((int32_t*)packed_w) =
          b ? b[nr_block_start + nr_block_offset] + boff : 0;
      packed_w = (void*)((uintptr_t)packed_w + sizeof(int32_t));
    }
    packed_w =
        (void*)((uintptr_t)packed_w + (nr - nr_block_size) * sizeof(int32_t));

    for (size_t kr_block_start = 0; kr_block_start < (kc & -sr);
         kr_block_start += kr) {
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
           nr_block_offset++) {
        for (size_t kr_block_offset = 0; kr_block_offset < kr;
             kr_block_offset++) {
          const uint8_t kv =
              k[(nr_block_start + nr_block_offset) * kc +
                (kr_block_start & -sr) +
                ((kr_block_start + nr_block_offset * kr) & (sr - 1)) +
                kr_block_offset];
          packed_b[nr_block_offset] -= (int32_t)kv * (int32_t)izp;
          *((uint8_t*)packed_w) = kv;
          packed_w = (void*)((uintptr_t)packed_w + sizeof(uint8_t));
        }
      }
      packed_w =
          (void*)((uintptr_t)packed_w +
                  (nr - nr_block_size) * kr * sizeof(uint8_t));
    }

    for (size_t kr_block_start = (kc & -sr); kr_block_start < kc;
         kr_block_start += kr) {
      const size_t kr_block_size = min(kc - kr_block_start, kr);
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
           nr_block_offset++) {
        for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
             kr_block_offset++) {
          const uint8_t kv =
              k[(nr_block_start + nr_block_offset) * kc +
                (kr_block_start + kr_block_offset)];
          packed_b[nr_block_offset] -= (int32_t)kv * (int32_t)izp;
          *((uint8_t*)packed_w) = kv;
          packed_w = (void*)((uintptr_t)packed_w + sizeof(uint8_t));
        }
        packed_w =
            (void*)((uintptr_t)packed_w +
                    (kr - kr_block_size) * sizeof(uint8_t));
      }
      packed_w =
          (void*)((uintptr_t)packed_w +
                  (nr - nr_block_size) * kr * sizeof(uint8_t));
    }
  }
}

static inline void pytorch_pack_swizzle_q8gemm_brq(
    const size_t n,
    const size_t kc,
    const uint32_t nr,
    const uint32_t kr,
    const uint32_t sr,
    const uint8_t* const k,
    const int32_t* const b,
    void* const packed_w) {
  union {
    void* const as_void_ptr;
    uint8_t* as_uint8_ptr;
    int32_t* as_int32_ptr;
  } packed = {packed_w};

  for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) {
    const size_t nr_block_size = min(n - nr_block_start, nr);
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *(packed.as_int32_ptr++) = b ? b[nr_block_start + nr_block_offset] : 0;
    }

    packed.as_int32_ptr += (nr - nr_block_size);

    for (size_t kr_block_start = 0; kr_block_start < (kc & -sr);
         kr_block_start += kr) {
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
           nr_block_offset++) {
        for (size_t kr_block_offset = 0; kr_block_offset < kr;
             kr_block_offset++) {
          const uint8_t kv =
              k[(nr_block_start + nr_block_offset) * kc +
                (kr_block_start & -sr) +
                ((kr_block_start + nr_block_offset * kr) & (sr - 1)) +
                kr_block_offset];
          *(packed.as_uint8_ptr++) = kv;
        }
      }
      packed.as_uint8_ptr += (nr - nr_block_size) * kr;
    }

    for (size_t kr_block_start = (kc & -sr); kr_block_start < kc;
         kr_block_start += kr) {
      const size_t kr_block_size = min(kc - kr_block_start, kr);
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
           nr_block_offset++) {
        for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
             kr_block_offset++) {
          const uint8_t kv =
              k[(nr_block_start + nr_block_offset) * kc +
                (kr_block_start + kr_block_offset)];
          *(packed.as_uint8_ptr++) = kv;
        }
        packed.as_uint8_ptr += (kr - kr_block_size);
      }
      packed.as_uint8_ptr += (nr - nr_block_size) * kr;
    }
  }
}
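
/*
 * A small helper (illustrative only, not part of the library API) showing
 * the swizzled source column used by the two packers above for row offset r
 * within an nr block: within each sr-aligned group of columns, row r starts
 * its kr-wide chunks rotated by r * kr, so consecutive rows pull different
 * kr chunks from the same group. The cast to size_t simply widens the mask
 * used in the loops above for the illustration.
 */
static inline size_t pytorch_swizzle_q8gemm_col_example(
    size_t kr_block_start,
    size_t r,
    uint32_t kr,
    uint32_t sr,
    size_t kr_block_offset) {
  return (kr_block_start & -(size_t)sr) +
      ((kr_block_start + r * kr) & (sr - 1)) + kr_block_offset;
}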

static inline void pytorch_pack_hgemm_w(
    size_t nc,
    size_t kc,
    size_t nr,
    size_t kr,
    const uint16_t* k,
    const uint16_t* b,
    uint16_t* packed_w) {
  for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
    const size_t nr_block_size = min(nc - nr_block_start, nr);
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *packed_w++ = b ? b[nr_block_start + nr_block_offset] : 0;
    }
    packed_w += nr - nr_block_size;
    for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
      const size_t kr_block_size = min(kc - kr_block_start, kr);
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
           nr_block_offset++) {
        for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
             kr_block_offset++) {
          *packed_w++ =
              k[(nr_block_start + nr_block_offset) * kc +
                (kr_block_start + kr_block_offset)];
        }
        packed_w += kr - kr_block_size;
      }
      packed_w += (nr - nr_block_size) * kr;
    }
  }
}

static inline void pytorch_pack_sgemm_w(
    size_t nc,
    size_t kc,
    size_t nr,
    size_t kr,
    const float* k,
    const float* b,
    float* packed_w) {
  for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
    const size_t nr_block_size = min(nc - nr_block_start, nr);
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *packed_w++ = b ? b[nr_block_start + nr_block_offset] : 0.0f;
    }
    packed_w += nr - nr_block_size;
    for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
      const size_t kr_block_size = min(kc - kr_block_start, kr);
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
           nr_block_offset++) {
        for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
             kr_block_offset++) {
          *packed_w++ =
              k[(nr_block_start + nr_block_offset) * kc +
                (kr_block_start + kr_block_offset)];
        }
        packed_w += kr - kr_block_size;
      }
      packed_w += (nr - nr_block_size) * kr;
    }
  }
}
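
/*
 * A minimal usage sketch (hypothetical dimensions, for illustration only):
 * pack a 5x3 float weight matrix for a ukernel with nr = 4 and kr = 2. Each
 * of the ceil(5/4) = 2 row blocks takes nr bias floats plus
 * round-up(3, 2) * 4 = 16 weight floats.
 */
static inline void pytorch_pack_sgemm_w_usage_example(void) {
  const float k[5 * 3] = {0}; /* weights, row-major [nc][kc] */
  const float b[5] = {0}; /* per-output-channel bias */
  float packed[2 * (4 + 4 * 4)]; /* 2 blocks * (4 bias + 16 weight floats) */
  pytorch_pack_sgemm_w(5, 3, 4, 2, k, b, packed);
}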

static inline void pytorch_pack_sconv_w(
    size_t n,
    size_t ks,
    size_t kc,
    size_t nr,
    size_t kr,
    const float* k,
    const float* b,
    float* packed_w) {
  for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) {
    const size_t nr_block_size = min(n - nr_block_start, nr);
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *packed_w++ = b ? b[nr_block_start + nr_block_offset] : 0.0f;
    }
    packed_w += nr - nr_block_size;
    for (size_t ki = 0; ki < ks; ki++) {
      for (size_t kr_block_start = 0; kr_block_start < kc;
           kr_block_start += kr) {
        const size_t kr_block_size = min(kc - kr_block_start, kr);
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
             nr_block_offset++) {
          for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
               kr_block_offset++) {
            *packed_w++ =
                k[((nr_block_start + nr_block_offset) * ks + ki) * kc +
                  (kr_block_start + kr_block_offset)];
          }
          packed_w += kr - kr_block_size;
        }
        packed_w += (nr - nr_block_size) * kr;
      }
    }
  }
}

#if PYTORCH_QNNPACK_RUNTIME_QUANTIZATION

#define pytorch_pack_q8gemm_w pytorch_pack_q8gemm_wrq
#define pytorch_pack_q8conv_w pytorch_pack_q8conv_wrq
#define pytorch_pack_q8deconv_w pytorch_pack_q8deconv_wrq
#define pytorch_pack_q8dw_w pytorch_pack_q8dw_wrq
#define pytorch_pack_swizzle_q8gemm_b pytorch_pack_swizzle_q8gemm_brq

#else

#define pytorch_pack_q8gemm_w pytorch_pack_q8gemm_wdq
#define pytorch_pack_q8conv_w pytorch_pack_q8conv_wdq
#define pytorch_pack_q8deconv_w pytorch_pack_q8deconv_wdq
#define pytorch_pack_q8dw_w pytorch_pack_q8dw_wdq
#define pytorch_pack_swizzle_q8gemm_b pytorch_pack_swizzle_q8gemm_bdq

#endif
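
// Note that the run-time (_wrq) and design-time (_wdq) variants do not share
// a signature: the _wdq packers take scalar izp/kzp zero points and fold the
// corresponding offsets into the packed bias, while the _wrq packers take a
// per-channel kzp array (or NULL) and leave the zero-point correction to the
// ukernel. Call sites therefore pass different arguments depending on which
// variant the macros above select.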