/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once
#include <qnnpack/math.h>
#include <stdint.h>

// Legend:
//  dq: Design-time Quantization
//  rq: Run-time Quantization

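// For the dq packers below, the bias folding follows from expanding the
// quantized dot product:
//
//   sum_k (x[k] - izp) * (w[n][k] - kzp)
//     = sum_k x[k] * w[n][k] - kzp * sum_k x[k]
//       - izp * sum_k w[n][k] + kc * izp * kzp
//
// The last two terms depend only on the weights, so the packers fold
// b[n] + kc * izp * kzp - izp * sum_k w[n][k] into the packed bias, and the
// kernels accumulate only the input-dependent terms.
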
static inline void pytorch_pack_q8gemm_wdq(
    size_t nc,
    size_t kc,
    uint32_t nr,
    uint32_t np,
    uint32_t kr,
    uint8_t izp,
    uint8_t kzp,
    const uint8_t* k,
    const int32_t* b,
    void* packed_w) {
  const int32_t boff = (int32_t)kc * (int32_t)izp * (int32_t)kzp;
  for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
    const size_t nr_block_size = min(nc - nr_block_start, nr);
    int32_t* packed_b = (int32_t*)packed_w;
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *((int32_t*)packed_w) =
          b ? b[nr_block_start + nr_block_offset] + boff : 0;
      packed_w = (void*)((uintptr_t)packed_w + sizeof(int32_t));
    }
    packed_w =
        (void*)((uintptr_t)packed_w + (nr - nr_block_size) * sizeof(int32_t));
    for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
      const size_t kr_block_size = min(kc - kr_block_start, kr);
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
           nr_block_offset++) {
        int32_t ksum = 0;
        for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
             kr_block_offset++) {
          const uint8_t kv =
              k[(nr_block_start + nr_block_offset) * kc +
                (kr_block_start + kr_block_offset)];
          ksum += (int32_t)kv;
          *((uint8_t*)packed_w) = kv;
          packed_w = (void*)((uintptr_t)packed_w + sizeof(uint8_t));
        }
        packed_b[nr_block_offset] -= ksum * (int32_t)izp;
        packed_w =
            (void*)((uintptr_t)packed_w + (kr - kr_block_size) * sizeof(uint8_t));
      }
      packed_w =
          (void*)((uintptr_t)packed_w + ((nr - nr_block_size) & (np - 1)) * kr * sizeof(uint8_t));
    }
  }
}
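
// A minimal usage sketch for the packer above. The stride and size math here
// is an assumption of this note (the header itself does not define it), but
// it bounds the layout written above: per nr block, nr int32 bias words
// followed by kr-wide weight slices.
//
//   const size_t n_stride = (nc + (nr - 1)) & -nr;
//   const size_t k_stride = (kc + (kr - 1)) & -kr;
//   void* packed_w = malloc(
//       n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t)));
//   pytorch_pack_q8gemm_wdq(nc, kc, nr, np, kr, izp, kzp, k, b, packed_w);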

// NB: We use the same packing function for both dynamic quantization
// and runtime quantization for linear.
// This means that dynamic mode pays a small performance cost for the
// branching introduced by `if (kzp != 0)`; however, that should not be
// too significant.
static inline void pytorch_pack_q8gemm_wrq(
    const size_t nc,
    const size_t kc,
    const uint32_t nr,
    const uint32_t np,
    const uint32_t kr,
    const uint8_t* const k,
    const int32_t* const b,
    const uint8_t* const kzp,
    void* const packed_w) {
  union {
    void* const as_void_ptr;
    uint8_t* as_uint8_ptr;
    int32_t* as_int32_ptr;
  } packed = {packed_w};

  for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
    const size_t nr_block_size = min(nc - nr_block_start, nr);
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *(packed.as_int32_ptr++) = b ? b[nr_block_start + nr_block_offset] : 0;
    }
    packed.as_int32_ptr += (nr - nr_block_size);
    for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
      const size_t kr_block_size = min(kc - kr_block_start, kr);
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
           nr_block_offset++) {
        for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
             kr_block_offset++) {
          const uint8_t kv =
              k[(nr_block_start + nr_block_offset) * kc +
                (kr_block_start + kr_block_offset)];
          *(packed.as_uint8_ptr++) = kv;
        }
        // Weights need to be prepacked with the zero points in their tail
        // space, where packed blocks are not a multiple of input sizes.
        // E.g., for ukernels with kr = 2 and k = 3, the second block must be
        // padded with the zero point, because subtracting the zero point
        // then yields zero for the padded value, which is what we want.
        if (kzp != 0) {
          for (size_t kr_block_offset = 0; kr_block_offset < (kr - kr_block_size);
               kr_block_offset++) {
            const uint8_t kv =
                kzp[(nr_block_start + nr_block_offset)];
            *(packed.as_uint8_ptr++) = kv;
          }
        } else {
          packed.as_uint8_ptr += (kr - kr_block_size);
        }
      }
      if (kzp != 0) {
        // This part fills the packed weights with zero points for output
        // channels when they are not divisible by the nr blocking parameter.
        // Some kernels (e.g. the SSE2 ones) rely on this to produce zero as
        // the result of subtracting the zero point from the weight value.
        size_t remaining_nr_blocks = ((nr - nr_block_size) & (np - 1));
        for (size_t nr_block_offset = 0; nr_block_offset < remaining_nr_blocks;
             nr_block_offset++) {
          for (size_t kr_block_offset = 0; kr_block_offset < kr;
               kr_block_offset++) {
            const uint8_t kv =
                kzp[(nr_block_start + nr_block_size + nr_block_offset)];
            *(packed.as_uint8_ptr++) = kv;
          }
        }
      } else {
        packed.as_uint8_ptr += ((nr - nr_block_size) & (np - 1)) * kr;
      }
    }
  }
}
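
// A worked illustration of the tail padding above (the parameters are
// assumptions for illustration only): with nr = 4, np = 4, kr = 2, nc = 4,
// and kc = 3, each nr block is laid out as nr bias words followed by kr-wide
// weight slices per output channel:
//
//   b0 b1 b2 b3 | w00 w01  w10 w11  w20 w21  w30 w31
//               | w02 z0   w12 z1   w22 z2   w32 z3
//
// where zi = kzp[i] pads the last k slice because kc = 3 is not a multiple
// of kr = 2, so the kernel's (weight - zero point) subtraction contributes
// exactly zero for the padding.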

static inline void pytorch_pack_q8conv_wdq(
    size_t n,
    size_t ks,
    size_t kc,
    uint32_t nr,
    uint32_t kr,
    uint8_t izp,
    uint8_t kzp,
    const uint8_t* k,
    const int32_t* b,
    void* packed_w) {
  const int32_t boff = (int32_t)ks * (int32_t)kc * (int32_t)izp * (int32_t)kzp;
  for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) {
    const size_t nr_block_size = min(n - nr_block_start, nr);
    int32_t* packed_b = (int32_t*)packed_w;
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *((int32_t*)packed_w) =
          b ? b[nr_block_start + nr_block_offset] + boff : 0;
      packed_w = (void*)((uintptr_t)packed_w + sizeof(int32_t));
    }
    packed_w =
        (void*)((uintptr_t)packed_w + (nr - nr_block_size) * sizeof(int32_t));
    for (size_t ki = 0; ki < ks; ki++) {
      for (size_t kr_block_start = 0; kr_block_start < kc;
           kr_block_start += kr) {
        const size_t kr_block_size = min(kc - kr_block_start, kr);
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
             nr_block_offset++) {
          int32_t ksum = 0;
          for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
               kr_block_offset++) {
            const uint8_t kv =
                k[((nr_block_start + nr_block_offset) * ks + ki) * kc +
                  (kr_block_start + kr_block_offset)];
            ksum += (int32_t)kv;
            *((uint8_t*)packed_w) = kv;
            packed_w = (void*)((uintptr_t)packed_w + sizeof(uint8_t));
          }
          packed_b[nr_block_offset] -= ksum * (int32_t)izp;
          packed_w =
              (void*)((uintptr_t)packed_w + (kr - kr_block_size) * sizeof(uint8_t));
        }
        packed_w =
            (void*)((uintptr_t)packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
      }
    }
  }
}

static inline void pytorch_pack_q8conv_wrq(
    const size_t n,
    const size_t ks,
    const size_t kc,
    const uint32_t nr,
    const uint32_t kr,
    const uint8_t* const k,
    const int32_t* const b,
    const uint8_t* const kzp,
    void* const packed_w) {
  union {
    void* const as_void_ptr;
    uint8_t* as_uint8_ptr;
    int32_t* as_int32_ptr;
  } packed = {packed_w};

  for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) {
    const size_t nr_block_size = min(n - nr_block_start, nr);
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *(packed.as_int32_ptr++) = b ? b[nr_block_start + nr_block_offset] : 0;
    }
    packed.as_int32_ptr += (nr - nr_block_size);
    for (size_t ki = 0; ki < ks; ki++) {
      for (size_t kr_block_start = 0; kr_block_start < kc;
           kr_block_start += kr) {
        const size_t kr_block_size = min(kc - kr_block_start, kr);
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
             nr_block_offset++) {
          for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
               kr_block_offset++) {
            const uint8_t kv =
                k[((nr_block_start + nr_block_offset) * ks + ki) * kc +
                  (kr_block_start + kr_block_offset)];
            *(packed.as_uint8_ptr++) = kv;
          }
          // Weights need to be prepacked with the zero points in their tail
          // space, where packed blocks are not a multiple of input sizes.
          // E.g., for ukernels with kr = 2 and k = 3, the second block must
          // be padded with the zero point, because subtracting the zero
          // point then yields zero for the padded value, which is what we
          // want.
          if (kzp != 0) {
            for (size_t kr_block_offset = 0; kr_block_offset < (kr - kr_block_size);
                 kr_block_offset++) {
              const uint8_t kv =
                  kzp[(nr_block_start + nr_block_offset)];
              *(packed.as_uint8_ptr++) = kv;
            }
          } else {
            packed.as_uint8_ptr += (kr - kr_block_size);
          }
        }
        if (kzp != 0) {
          // This part fills the packed weights with zero points for output
          // channels when they are not divisible by the nr blocking
          // parameter, so that subtracting the zero point yields zero.
          for (size_t nr_block_offset = 0; nr_block_offset < (nr - nr_block_size);
               nr_block_offset++) {
            for (size_t kr_block_offset = 0; kr_block_offset < kr;
                 kr_block_offset++) {
              const uint8_t kv =
                  kzp[(nr_block_start + nr_block_size + nr_block_offset)];
              *(packed.as_uint8_ptr++) = kv;
            }
          }
        } else {
          packed.as_uint8_ptr += (nr - nr_block_size) * kr;
        }
      }
    }
  }
}

static inline void pytorch_pack_q8deconv_wdq(
    size_t n,
    size_t ks,
    size_t kc,
    uint32_t nr,
    uint32_t kr,
    uint8_t izp,
    uint8_t kzp,
    const uint8_t* k,
    const int32_t* b,
    void* packed_w) {
  const int32_t boff = (int32_t)ks * (int32_t)kc * (int32_t)izp * (int32_t)kzp;
  for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) {
    const size_t nr_block_size = min(n - nr_block_start, nr);
    int32_t* packed_b = (int32_t*)packed_w;
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *((int32_t*)packed_w) =
          b ? b[nr_block_start + nr_block_offset] + boff : 0;
      packed_w = (void*)((uintptr_t)packed_w + sizeof(int32_t));
    }
    packed_w =
        (void*)((uintptr_t)packed_w + (nr - nr_block_size) * sizeof(int32_t));
    for (size_t ki = 0; ki < ks; ki++) {
      for (size_t kr_block_start = 0; kr_block_start < kc;
           kr_block_start += kr) {
        const size_t kr_block_size = min(kc - kr_block_start, kr);
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
             nr_block_offset++) {
          int32_t ksum = 0;
          for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
               kr_block_offset++) {
            const uint8_t kv =
                k[((kr_block_start + kr_block_offset) * ks + ki) * n +
                  (nr_block_start + nr_block_offset)];
            ksum += (int32_t)kv;
            *((uint8_t*)packed_w) = kv;
            packed_w = (void*)((uintptr_t)packed_w + sizeof(uint8_t));
          }
          packed_b[nr_block_offset] -= ksum * (int32_t)izp;
          packed_w =
              (void*)((uintptr_t)packed_w + (kr - kr_block_size) * sizeof(uint8_t));
        }
        packed_w =
            (void*)((uintptr_t)packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
      }
    }
  }
}
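
// Note: the deconv packers read the weights transposed relative to the conv
// packers above. Here the source index is
//   k[((kr_block_start + kr_block_offset) * ks + ki) * n + (nr_block_start + nr_block_offset)]
// (input channels outermost), whereas q8conv reads
//   k[((nr_block_start + nr_block_offset) * ks + ki) * kc + (kr_block_start + kr_block_offset)]
// (output channels outermost).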

static inline void pytorch_pack_q8deconv_wrq(
    const size_t n,
    const size_t ks,
    const size_t kc,
    const uint32_t nr,
    const uint32_t kr,
    const uint8_t* const k,
    const int32_t* const b,
    const uint8_t* const kzp,
    void* const packed_w) {
  union {
    void* const as_void_ptr;
    uint8_t* as_uint8_ptr;
    int32_t* as_int32_ptr;
  } packed = {packed_w};

  for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) {
    const size_t nr_block_size = min(n - nr_block_start, nr);
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *(packed.as_int32_ptr++) = b ? b[nr_block_start + nr_block_offset] : 0;
    }
    packed.as_int32_ptr += (nr - nr_block_size);
    for (size_t ki = 0; ki < ks; ki++) {
      for (size_t kr_block_start = 0; kr_block_start < kc;
           kr_block_start += kr) {
        const size_t kr_block_size = min(kc - kr_block_start, kr);
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
             nr_block_offset++) {
          for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
               kr_block_offset++) {
            const uint8_t kv =
                k[((kr_block_start + kr_block_offset) * ks + ki) * n +
                  (nr_block_start + nr_block_offset)];
            *(packed.as_uint8_ptr++) = kv;
          }
          // Weights need to be prepacked with the zero points in their tail
          // space, where packed blocks are not a multiple of input sizes.
          // E.g., for ukernels with kr = 2 and k = 3, the second block must
          // be padded with the zero point, because subtracting the zero
          // point then yields zero for the padded value, which is what we
          // want.
          if (kzp != 0) {
            for (size_t kr_block_offset = 0; kr_block_offset < (kr - kr_block_size);
                 kr_block_offset++) {
              const uint8_t kv =
                  kzp[(nr_block_start + nr_block_offset)];
              *(packed.as_uint8_ptr++) = kv;
            }
          } else {
            packed.as_uint8_ptr += (kr - kr_block_size);
          }
        }
        if (kzp != 0) {
          // This part fills the packed weights with zero points for output
          // channels when they are not divisible by the nr blocking
          // parameter, so that subtracting the zero point yields zero.
          for (size_t nr_block_offset = 0; nr_block_offset < (nr - nr_block_size);
               nr_block_offset++) {
            for (size_t kr_block_offset = 0; kr_block_offset < kr;
                 kr_block_offset++) {
              const uint8_t kv =
                  kzp[(nr_block_start + nr_block_size + nr_block_offset)];
              *(packed.as_uint8_ptr++) = kv;
            }
          }
        } else {
          packed.as_uint8_ptr += (nr - nr_block_size) * kr;
        }
      }
    }
  }
}

static inline void pytorch_pack_q8dw_wdq(
    size_t h,
    size_t w,
    size_t c,
    size_t cr,
    uint8_t izp,
    uint8_t* kzp,
    const uint8_t* k,
    const int32_t* b,
    void* packed_w) {
  const int32_t boff = (int32_t)h * (int32_t)w * (int32_t)izp;
  for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
    const size_t cr_block_size = min(c - cr_block_start, cr);
    int32_t* packed_b = (int32_t*)packed_w;
    for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
         cr_block_offset++) {
      *((int32_t*)packed_w) =
          b ? b[cr_block_start + cr_block_offset] +
              boff * kzp[cr_block_start + cr_block_offset]
            : 0;
      packed_w = (void*)((uintptr_t)packed_w + sizeof(int32_t));
    }
    packed_w =
        (void*)((uintptr_t)packed_w + (cr - cr_block_size) * sizeof(int32_t));
    for (size_t x = 0; x < w; x++) {
      for (size_t y = 0; y < h; y++) {
        for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
             cr_block_offset++) {
          const uint8_t kv =
              k[((cr_block_start + cr_block_offset) * h + y) * w + x];
          packed_b[cr_block_offset] -= (int32_t)kv * (int32_t)izp;
          *((uint8_t*)packed_w) = kv;
          packed_w = (void*)((uintptr_t)packed_w + sizeof(uint8_t));
        }
        packed_w =
            (void*)((uintptr_t)packed_w + (cr - cr_block_size) * sizeof(uint8_t));
      }
    }
  }
}
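
// For the depthwise dq packer the bias folding is per channel: packed_b[c]
// starts at b[c] + h * w * izp * kzp[c], and the tap loop then subtracts
// izp * k[c][y][x] for every tap, i.e.
//
//   packed_b[c] = b[c] + izp * sum_{y,x} (kzp[c] - k[c][y][x])
//
// leaving only the input-dependent terms for the kernel to accumulate.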

static inline void pytorch_pack_q8dw_wrq(
    const size_t h,
    const size_t w,
    const size_t c,
    const size_t cr,
    const uint8_t* const k,
    const int32_t* const b,
    void* const packed_w) {
  union {
    void* const as_void_ptr;
    uint8_t* as_uint8_ptr;
    int32_t* as_int32_ptr;
  } packed = {packed_w};

  for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
    const size_t cr_block_size = min(c - cr_block_start, cr);
    for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
         cr_block_offset++) {
      *(packed.as_int32_ptr++) = b ? b[cr_block_start + cr_block_offset] : 0;
    }
    packed.as_int32_ptr += (cr - cr_block_size);
    for (size_t x = 0; x < w; x++) {
      for (size_t y = 0; y < h; y++) {
        for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
             cr_block_offset++) {
          const uint8_t kv =
              k[((cr_block_start + cr_block_offset) * h + y) * w + x];
          *(packed.as_uint8_ptr++) = kv;
        }
        packed.as_uint8_ptr += (cr - cr_block_size);
      }
    }
  }
}
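
// Note that both depthwise packers walk the spatial taps x-major (x outer,
// y inner), so within a cr block the packed tap order is (x, y) = (0, 0),
// (0, 1), ..., (0, h - 1), (1, 0), ..., transposed relative to the source
// layout k[c][y][x].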

static inline void pytorch_pack_q8dw_3d_w_dilation(
    size_t d,
    size_t h,
    size_t w,
    size_t c,
    size_t cr,
    size_t z_start,
    size_t z_end,
    size_t y_start,
    size_t y_end,
    size_t x_start,
    size_t x_end,
    const uint8_t* k,
    const int32_t* b,
    void* packed_w,
    bool pytorch_pack_b) {
  for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
    const size_t cr_block_size = min(c - cr_block_start, cr);
    if (pytorch_pack_b) {
      for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
           cr_block_offset++) {
        *((int32_t*)packed_w) = b ? b[cr_block_start + cr_block_offset] : 0;
        packed_w = (void*)((int32_t*)packed_w + 1);
      }
      packed_w =
          (void*)((int32_t*)packed_w + (cr - cr_block_size));
    }
    for (size_t x = x_start; x < x_end; x++) {
      for (size_t y = y_start; y < y_end; y++) {
        for (size_t z = z_start; z < z_end; z++) {
          for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
              cr_block_offset++) {
            *((uint8_t*)packed_w) =
                k[(((cr_block_start + cr_block_offset) * d + z) * h + y) * w + x];
            packed_w = (void*)((uint8_t*)packed_w + 1);
          }
          packed_w =
              (void*)((uint8_t*)packed_w + (cr - cr_block_size));
        }
      }
    }
  }
}

static inline void pytorch_pack_q8dw_2d_w_dilation(
    size_t h,
    size_t w,
    size_t c,
    size_t cr,
    size_t y_start,
    size_t y_end,
    size_t x_start,
    size_t x_end,
    const uint8_t* k,
    const int32_t* b,
    void* packed_w,
    bool pytorch_pack_b) {
  pytorch_pack_q8dw_3d_w_dilation(
      1, /* d */
      h,
      w,
      c,
      cr,
      0, /* z_start */
      1, /* z_end */
      y_start,
      y_end,
      x_start,
      x_end,
      k,
      b,
      packed_w,
      pytorch_pack_b);
}
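
// A minimal usage sketch (illustrative only; the shapes here are assumptions
// of this note): packing just the first column of a 3x3 depthwise kernel,
// without the bias row:
//
//   pytorch_pack_q8dw_2d_w_dilation(
//       3 /* h */, 3 /* w */, c, cr,
//       0 /* y_start */, 3 /* y_end */,
//       0 /* x_start */, 1 /* x_end */,
//       k, NULL /* b */, packed_w,
//       false /* pytorch_pack_b */);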

static inline void pytorch_pack_swizzle_q8gemm_bdq(
    size_t n,
    size_t kc,
    uint32_t nr,
    uint32_t kr,
    uint32_t sr,
    uint8_t izp,
    uint8_t kzp,
    const uint8_t* k,
    const int32_t* b,
    void* packed_w) {
  const int32_t boff = (int32_t)kc * (int32_t)izp * (int32_t)kzp;
  for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) {
    const size_t nr_block_size = min(n - nr_block_start, nr);
    int32_t* packed_b = (int32_t*)packed_w;
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *((int32_t*)packed_w) =
          b ? b[nr_block_start + nr_block_offset] + boff : 0;
      packed_w = (void*)((uintptr_t)packed_w + sizeof(int32_t));
    }
    packed_w =
        (void*)((uintptr_t)packed_w + (nr - nr_block_size) * sizeof(int32_t));

    for (size_t kr_block_start = 0; kr_block_start < (kc & -sr);
         kr_block_start += kr) {
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
           nr_block_offset++) {
        for (size_t kr_block_offset = 0; kr_block_offset < kr;
             kr_block_offset++) {
          const uint8_t kv =
              k[(nr_block_start + nr_block_offset) * kc +
                (kr_block_start & -sr) +
                ((kr_block_start + nr_block_offset * kr) & (sr - 1)) +
                kr_block_offset];
          packed_b[nr_block_offset] -= (int32_t)kv * (int32_t)izp;
          *((uint8_t*)packed_w) = kv;
          packed_w = (void*)((uintptr_t)packed_w + sizeof(uint8_t));
        }
      }
      packed_w =
          (void*)((uintptr_t)packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
    }

    for (size_t kr_block_start = (kc & -sr); kr_block_start < kc;
         kr_block_start += kr) {
      const size_t kr_block_size = min(kc - kr_block_start, kr);
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
           nr_block_offset++) {
        for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
             kr_block_offset++) {
          const uint8_t kv =
              k[(nr_block_start + nr_block_offset) * kc +
                (kr_block_start + kr_block_offset)];
          packed_b[nr_block_offset] -= (int32_t)kv * (int32_t)izp;
          *((uint8_t*)packed_w) = kv;
          packed_w = (void*)((uintptr_t)packed_w + sizeof(uint8_t));
        }
        packed_w =
            (void*)((uintptr_t)packed_w + (kr - kr_block_size) * sizeof(uint8_t));
      }
      packed_w =
          (void*)((uintptr_t)packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
    }
  }
}

static inline void pytorch_pack_swizzle_q8gemm_brq(
    const size_t n,
    const size_t kc,
    const uint32_t nr,
    const uint32_t kr,
    const uint32_t sr,
    const uint8_t* const k,
    const int32_t* const b,
    void* const packed_w) {
  union {
    void* const as_void_ptr;
    uint8_t* as_uint8_ptr;
    int32_t* as_int32_ptr;
  } packed = {packed_w};

  for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) {
    const size_t nr_block_size = min(n - nr_block_start, nr);
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *(packed.as_int32_ptr++) = b ? b[nr_block_start + nr_block_offset] : 0;
    }

    packed.as_int32_ptr += (nr - nr_block_size);

    for (size_t kr_block_start = 0; kr_block_start < (kc & -sr);
         kr_block_start += kr) {
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
           nr_block_offset++) {
        for (size_t kr_block_offset = 0; kr_block_offset < kr;
             kr_block_offset++) {
          const uint8_t kv =
              k[(nr_block_start + nr_block_offset) * kc +
                (kr_block_start & -sr) +
                ((kr_block_start + nr_block_offset * kr) & (sr - 1)) +
                kr_block_offset];
          *(packed.as_uint8_ptr++) = kv;
        }
      }
      packed.as_uint8_ptr += (nr - nr_block_size) * kr;
    }

    for (size_t kr_block_start = (kc & -sr); kr_block_start < kc;
         kr_block_start += kr) {
      const size_t kr_block_size = min(kc - kr_block_start, kr);
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
           nr_block_offset++) {
        for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
             kr_block_offset++) {
          const uint8_t kv =
              k[(nr_block_start + nr_block_offset) * kc +
                (kr_block_start + kr_block_offset)];
          *(packed.as_uint8_ptr++) = kv;
        }
        packed.as_uint8_ptr += (kr - kr_block_size);
      }
      packed.as_uint8_ptr += (nr - nr_block_size) * kr;
    }
  }
}
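
// A worked example of the swizzled source index above (the blocking
// parameters here are assumptions for illustration, not fixed by this
// header): with kr = 2 and sr = 4, the k offset is
//   (kr_block_start & -sr)
//   + ((kr_block_start + nr_block_offset * kr) & (sr - 1))
//   + kr_block_offset
// so for the sr group starting at k = 0:
//
//   row 0 packs k = 0, 1 and then k = 2, 3
//   row 1 packs k = 2, 3 and then k = 0, 1
//
// i.e. consecutive rows consume the kr-wide chunks of each sr group in a
// rotated order, presumably so the matching kernels can interleave loads.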

static inline void pytorch_pack_hgemm_w(
    size_t nc,
    size_t kc,
    size_t nr,
    size_t kr,
    const uint16_t* k,
    const uint16_t* b,
    uint16_t* packed_w) {
  for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
    const size_t nr_block_size = min(nc - nr_block_start, nr);
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *packed_w++ = b ? b[nr_block_start + nr_block_offset] : 0;
    }
    packed_w += nr - nr_block_size;
    for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
      const size_t kr_block_size = min(kc - kr_block_start, kr);
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
           nr_block_offset++) {
        for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
             kr_block_offset++) {
          *packed_w++ =
              k[(nr_block_start + nr_block_offset) * kc +
                (kr_block_start + kr_block_offset)];
        }
        packed_w += kr - kr_block_size;
      }
      packed_w += (nr - nr_block_size) * kr;
    }
  }
}

static inline void pytorch_pack_sgemm_w(
    size_t nc,
    size_t kc,
    size_t nr,
    size_t kr,
    const float* k,
    const float* b,
    float* packed_w) {
  for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
    const size_t nr_block_size = min(nc - nr_block_start, nr);
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *packed_w++ = b ? b[nr_block_start + nr_block_offset] : 0.0f;
    }
    packed_w += nr - nr_block_size;
    for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
      const size_t kr_block_size = min(kc - kr_block_start, kr);
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
           nr_block_offset++) {
        for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
             kr_block_offset++) {
          *packed_w++ =
              k[(nr_block_start + nr_block_offset) * kc +
                (kr_block_start + kr_block_offset)];
        }
        packed_w += kr - kr_block_size;
      }
      packed_w += (nr - nr_block_size) * kr;
    }
  }
}

static inline void pytorch_pack_sconv_w(
    size_t n,
    size_t ks,
    size_t kc,
    size_t nr,
    size_t kr,
    const float* k,
    const float* b,
    float* packed_w) {
  for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) {
    const size_t nr_block_size = min(n - nr_block_start, nr);
    for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
         nr_block_offset++) {
      *packed_w++ = b ? b[nr_block_start + nr_block_offset] : 0.0f;
    }
    packed_w += nr - nr_block_size;
    for (size_t ki = 0; ki < ks; ki++) {
      for (size_t kr_block_start = 0; kr_block_start < kc;
           kr_block_start += kr) {
        const size_t kr_block_size = min(kc - kr_block_start, kr);
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
             nr_block_offset++) {
          for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
               kr_block_offset++) {
            *packed_w++ =
                k[((nr_block_start + nr_block_offset) * ks + ki) * kc +
                  (kr_block_start + kr_block_offset)];
          }
          packed_w += kr - kr_block_size;
        }
        packed_w += (nr - nr_block_size) * kr;
      }
    }
  }
}

#if PYTORCH_QNNPACK_RUNTIME_QUANTIZATION

#define pytorch_pack_q8gemm_w pytorch_pack_q8gemm_wrq
#define pytorch_pack_q8conv_w pytorch_pack_q8conv_wrq
#define pytorch_pack_q8deconv_w pytorch_pack_q8deconv_wrq
#define pytorch_pack_q8dw_w pytorch_pack_q8dw_wrq
#define pytorch_pack_swizzle_q8gemm_b pytorch_pack_swizzle_q8gemm_brq

#else

#define pytorch_pack_q8gemm_w pytorch_pack_q8gemm_wdq
#define pytorch_pack_q8conv_w pytorch_pack_q8conv_wdq
#define pytorch_pack_q8deconv_w pytorch_pack_q8deconv_wdq
#define pytorch_pack_q8dw_w pytorch_pack_q8dw_wdq
#define pytorch_pack_swizzle_q8gemm_b pytorch_pack_swizzle_q8gemm_bdq

#endif
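
// Note that the rq and dq variants do not share a signature (the rq packers
// take a per-channel `kzp` pointer and no `izp`), so call sites written
// against these unsuffixed names must match whichever mode the library is
// compiled in.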