xref: /aosp_15_r20/external/XNNPACK/src/xnnpack/pack.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #pragma once
10 
11 #include <stdint.h>
12 #include <stddef.h>
13 
14 #include <xnnpack/common.h>
15 #include <xnnpack/operator.h>
16 
17 
18 #ifdef __cplusplus
19 extern "C" {
20 #endif
21 
22 
// Parameter block for the QU8 packing routines: every xnn_pack_qu8_* function
// declared in this header takes a `const struct xnn_qu8_packing_params*` as its
// `params` argument. Both fields are unsigned 8-bit values.
23 struct xnn_qu8_packing_params {
  // NOTE(review): presumably the quantization zero point of the input tensor — confirm.
24   uint8_t input_zero_point;
  // NOTE(review): presumably the quantization zero point of the kernel/weights — confirm.
25   uint8_t kernel_zero_point;
26 };
27 
// Parameter block for the QS8 packing routines: every xnn_pack_qs8_* function
// declared in this header takes a `const struct xnn_qs8_packing_params*` as its
// `params` argument. Unlike xnn_qu8_packing_params it has a single signed 8-bit
// field and no kernel zero point.
28 struct xnn_qs8_packing_params {
  // NOTE(review): presumably the quantization zero point of the input tensor — confirm.
29   int8_t input_zero_point;
30 };
31 
32 
// Type-erased function-pointer type mirroring the xnn_pack_*_gemm_goi_w routines
// declared in this header: the same eleven parameters in the same order, with the
// `k`, `b`, `packed_w` and `params` pointers expressed as void pointers.
// NOTE(review): the per-type routines use typed pointers, so their function types
// are not identical to this one — check how call sites convert between them.
// NOTE(review): g/nc/kc look like group, output-channel and input-channel counts,
// and nr/kr/sr like micro-kernel tiling factors — confirm against the implementation.
33 typedef void (*xnn_pack_gemm_goi_w_function)(
34   size_t g,
35   size_t nc,
36   size_t kc,
37   size_t nr,
38   size_t kr,
39   size_t sr,
40   const void* k,
41   const void* b,
42   void* packed_w,
43   size_t extra_bytes,
44   const void* params);
45 
// f32 member of the gemm_goi packing family (per its name).
// `k` and `b` are const float pointers, `packed_w` is a non-const float pointer,
// and `params` is an opaque const void pointer.
46 XNN_INTERNAL void xnn_pack_f32_gemm_goi_w(
47   size_t g,
48   size_t nc,
49   size_t kc,
50   size_t nr,
51   size_t kr,
52   size_t sr,
53   const float* k,
54   const float* b,
55   float* packed_w,
56   size_t extra_bytes,
57   const void* params);
58 
// f16 member of the gemm_goi packing family (per its name).
// `k`, `b` and `packed_w` are all uint16_t pointers (the first two const).
// NOTE(review): uint16_t presumably carries half-precision bit patterns — confirm.
59 XNN_INTERNAL void xnn_pack_f16_gemm_goi_w(
60   size_t g,
61   size_t nc,
62   size_t kc,
63   size_t nr,
64   size_t kr,
65   size_t sr,
66   const uint16_t* k,
67   const uint16_t* b,
68   uint16_t* packed_w,
69   size_t extra_bytes,
70   const void* params);
71 
// Mixed-type member of the gemm_goi packing family: `k` and `b` are const float
// pointers while `packed_w` is a uint16_t pointer.
// NOTE(review): the name suggests float inputs are converted to half precision
// while packing — confirm against the implementation.
72 XNN_INTERNAL void xnn_pack_f32_to_f16_gemm_goi_w(
73   size_t g,
74   size_t nc,
75   size_t kc,
76   size_t nr,
77   size_t kr,
78   size_t sr,
79   const float* k,
80   const float* b,
81   uint16_t* packed_w,
82   size_t extra_bytes,
83   const void* params);
84 
// QU8 member of the gemm_goi packing family (per its name).
// `k` is const uint8_t*, `b` is const int32_t*, `packed_w` is an untyped void*,
// and `params` is typed as `const struct xnn_qu8_packing_params*`.
85 XNN_INTERNAL void xnn_pack_qu8_gemm_goi_w(
86   size_t g,
87   size_t nc,
88   size_t kc,
89   size_t nr,
90   size_t kr,
91   size_t sr,
92   const uint8_t* k,
93   const int32_t* b,
94   void* packed_w,
95   size_t extra_bytes,
96   const struct xnn_qu8_packing_params* params);
97 
// QS8 member of the gemm_goi packing family (per its name).
// `k` is const int8_t*, `b` is const int32_t*, `packed_w` is an untyped void*,
// and `params` is typed as `const struct xnn_qs8_packing_params*`.
98 XNN_INTERNAL void xnn_pack_qs8_gemm_goi_w(
99   size_t g,
100   size_t nc,
101   size_t kc,
102   size_t nr,
103   size_t kr,
104   size_t sr,
105   const int8_t* k,
106   const int32_t* b,
107   void* packed_w,
108   size_t extra_bytes,
109   const struct xnn_qs8_packing_params* params);
110 
// "xw" variant of the QS8 gemm_goi packing routine. Its parameter list is
// identical to xnn_pack_qs8_gemm_goi_w (same types, names and order).
// NOTE(review): what "xw" changes in the packed output is not visible from this
// declaration — consult the implementation before choosing between the two.
111 XNN_INTERNAL void xnn_pack_qs8_gemm_xw_goi_w(
112   size_t g,
113   size_t nc,
114   size_t kc,
115   size_t nr,
116   size_t kr,
117   size_t sr,
118   const int8_t* k,
119   const int32_t* b,
120   void* packed_w,
121   size_t extra_bytes,
122   const struct xnn_qs8_packing_params* params);
123 
124 
// Type-erased function-pointer type mirroring the xnn_pack_*_gemm_io_w routines
// declared in this header: nine parameters, with all data pointers as void pointers.
// Compared with xnn_pack_gemm_goi_w_function there is no `g` and no `extra_bytes`.
// NOTE(review): "io" vs "goi" presumably names the source weight layout — confirm.
125 typedef void (*xnn_pack_gemm_io_w_function)(
126   size_t nc,
127   size_t kc,
128   size_t nr,
129   size_t kr,
130   size_t sr,
131   const void* k,
132   const void* b,
133   void* packed_w,
134   const void* params);
135 
// f32 member of the gemm_io packing family (per its name).
// `k` and `b` are const float pointers; `packed_w` is a non-const float pointer.
136 XNN_INTERNAL void xnn_pack_f32_gemm_io_w(
137   size_t nc,
138   size_t kc,
139   size_t nr,
140   size_t kr,
141   size_t sr,
142   const float* k,
143   const float* b,
144   float* packed_w,
145   const void* params);
146 
// f16 member of the gemm_io packing family (per its name).
// `k`, `b` and `packed_w` are uint16_t pointers (the first two const).
// NOTE(review): uint16_t presumably carries half-precision bit patterns — confirm.
147 XNN_INTERNAL void xnn_pack_f16_gemm_io_w(
148   size_t nc,
149   size_t kc,
150   size_t nr,
151   size_t kr,
152   size_t sr,
153   const uint16_t* k,
154   const uint16_t* b,
155   uint16_t* packed_w,
156   const void* params);
157 
// Mixed-type member of the gemm_io packing family: const float `k` and `b`,
// uint16_t `packed_w`.
// NOTE(review): the name suggests float→half conversion during packing — confirm.
158 XNN_INTERNAL void xnn_pack_f32_to_f16_gemm_io_w(
159   size_t nc,
160   size_t kc,
161   size_t nr,
162   size_t kr,
163   size_t sr,
164   const float* k,
165   const float* b,
166   uint16_t* packed_w,
167   const void* params);
168 
// QU8 member of the gemm_io packing family (per its name).
// `k` is const uint8_t*, `b` is const int32_t*, `packed_w` is void*;
// `params` is typed as `const struct xnn_qu8_packing_params*`.
169 XNN_INTERNAL void xnn_pack_qu8_gemm_io_w(
170   size_t nc,
171   size_t kc,
172   size_t nr,
173   size_t kr,
174   size_t sr,
175   const uint8_t* k,
176   const int32_t* b,
177   void* packed_w,
178   const struct xnn_qu8_packing_params* params);
179 
// QS8 member of the gemm_io packing family (per its name).
// `k` is const int8_t*, `b` is const int32_t*, `packed_w` is void*;
// `params` is typed as `const struct xnn_qs8_packing_params*`.
180 XNN_INTERNAL void xnn_pack_qs8_gemm_io_w(
181   size_t nc,
182   size_t kc,
183   size_t nr,
184   size_t kr,
185   size_t sr,
186   const int8_t* k,
187   const int32_t* b,
188   void* packed_w,
189   const struct xnn_qs8_packing_params* params);
190 
191 
// Type-erased function-pointer type mirroring the xnn_pack_*_conv_goki_w routines
// declared in this header: twelve parameters, with all data pointers as void pointers.
// Relative to xnn_pack_gemm_goi_w_function it adds a `ks` argument after `nc`.
// NOTE(review): `ks` is presumably the kernel size (number of kernel taps) and
// "goki" the source weight layout — confirm against the implementation.
192 typedef void (*xnn_pack_conv_goki_w_function)(
193   size_t g,
194   size_t nc,
195   size_t ks,
196   size_t kc,
197   size_t nr,
198   size_t kr,
199   size_t sr,
200   const void* k,
201   const void* b,
202   void* packed_w,
203   size_t extra_bytes,
204   const void* params);
205 
// f32 member of the conv_goki packing family (per its name).
// `k` and `b` are const float pointers; `packed_w` is a non-const float pointer.
206 XNN_INTERNAL void xnn_pack_f32_conv_goki_w(
207   size_t g,
208   size_t nc,
209   size_t ks,
210   size_t kc,
211   size_t nr,
212   size_t kr,
213   size_t sr,
214   const float* k,
215   const float* b,
216   float* packed_w,
217   size_t extra_bytes,
218   const void* params);
219 
// f16 member of the conv_goki packing family (per its name).
// `k`, `b` and `packed_w` are uint16_t pointers (the first two const).
// NOTE(review): uint16_t presumably carries half-precision bit patterns — confirm.
220 XNN_INTERNAL void xnn_pack_f16_conv_goki_w(
221   size_t g,
222   size_t nc,
223   size_t ks,
224   size_t kc,
225   size_t nr,
226   size_t kr,
227   size_t sr,
228   const uint16_t* k,
229   const uint16_t* b,
230   uint16_t* packed_w,
231   size_t extra_bytes,
232   const void* params);
233 
// Mixed-type member of the conv_goki packing family: const float `k` and `b`,
// uint16_t `packed_w`.
// NOTE(review): the name suggests float→half conversion during packing — confirm.
234 XNN_INTERNAL void xnn_pack_f32_to_f16_conv_goki_w(
235   size_t g,
236   size_t nc,
237   size_t ks,
238   size_t kc,
239   size_t nr,
240   size_t kr,
241   size_t sr,
242   const float* k,
243   const float* b,
244   uint16_t* packed_w,
245   size_t extra_bytes,
246   const void* params);
247 
// QU8 member of the conv_goki packing family (per its name).
// `k` is const uint8_t*, `b` is const int32_t*, `packed_w` is void*;
// `params` is typed as `const struct xnn_qu8_packing_params*`.
248 XNN_INTERNAL void xnn_pack_qu8_conv_goki_w(
249   size_t g,
250   size_t nc,
251   size_t ks,
252   size_t kc,
253   size_t nr,
254   size_t kr,
255   size_t sr,
256   const uint8_t* k,
257   const int32_t* b,
258   void* packed_w,
259   size_t extra_bytes,
260   const struct xnn_qu8_packing_params* params);
261 
// QS8 member of the conv_goki packing family (per its name).
// `k` is const int8_t*, `b` is const int32_t*, `packed_w` is void*;
// `params` is typed as `const struct xnn_qs8_packing_params*`.
262 XNN_INTERNAL void xnn_pack_qs8_conv_goki_w(
263   size_t g,
264   size_t nc,
265   size_t ks,
266   size_t kc,
267   size_t nr,
268   size_t kr,
269   size_t sr,
270   const int8_t* k,
271   const int32_t* b,
272   void* packed_w,
273   size_t extra_bytes,
274   const struct xnn_qs8_packing_params* params);
275 
276 
// Type-erased function-pointer type mirroring the xnn_pack_*_conv_kgo_w routines
// declared in this header: eleven parameters, with all data pointers as void pointers.
// Unlike xnn_pack_conv_goki_w_function there is no `kc` argument.
// NOTE(review): "kgo" presumably names the source weight layout — confirm.
277 typedef void (*xnn_pack_conv_kgo_w_function)(
278   size_t g,
279   size_t nc,
280   size_t ks,
281   size_t nr,
282   size_t kr,
283   size_t sr,
284   const void* k,
285   const void* b,
286   void* packed_w,
287   size_t extra_bytes,
288   const void* params);
289 
// f32 member of the conv_kgo packing family (per its name).
// `k` and `b` are const float pointers; `packed_w` is a non-const float pointer.
290 XNN_INTERNAL void xnn_pack_f32_conv_kgo_w(
291   size_t g,
292   size_t nc,
293   size_t ks,
294   size_t nr,
295   size_t kr,
296   size_t sr,
297   const float* k,
298   const float* b,
299   float* packed_w,
300   size_t extra_bytes,
301   const void* params);
302 
// f16 member of the conv_kgo packing family (per its name).
// `k`, `b` and `packed_w` are uint16_t pointers (the first two const).
// NOTE(review): uint16_t presumably carries half-precision bit patterns — confirm.
303 XNN_INTERNAL void xnn_pack_f16_conv_kgo_w(
304   size_t g,
305   size_t nc,
306   size_t ks,
307   size_t nr,
308   size_t kr,
309   size_t sr,
310   const uint16_t* k,
311   const uint16_t* b,
312   uint16_t* packed_w,
313   size_t extra_bytes,
314   const void* params);
315 
// Mixed-type member of the conv_kgo packing family: const float `k` and `b`,
// uint16_t `packed_w`.
// NOTE(review): the name suggests float→half conversion during packing — confirm.
316 XNN_INTERNAL void xnn_pack_f32_to_f16_conv_kgo_w(
317   size_t g,
318   size_t nc,
319   size_t ks,
320   size_t nr,
321   size_t kr,
322   size_t sr,
323   const float* k,
324   const float* b,
325   uint16_t* packed_w,
326   size_t extra_bytes,
327   const void* params);
328 
// QU8 member of the conv_kgo packing family (per its name).
// `k` is const uint8_t*, `b` is const int32_t*, `packed_w` is void*;
// `params` is typed as `const struct xnn_qu8_packing_params*`.
329 XNN_INTERNAL void xnn_pack_qu8_conv_kgo_w(
330   size_t g,
331   size_t nc,
332   size_t ks,
333   size_t nr,
334   size_t kr,
335   size_t sr,
336   const uint8_t* k,
337   const int32_t* b,
338   void* packed_w,
339   size_t extra_bytes,
340   const struct xnn_qu8_packing_params* params);
341 
// QS8 member of the conv_kgo packing family (per its name).
// `k` is const int8_t*, `b` is const int32_t*, `packed_w` is void*;
// `params` is typed as `const struct xnn_qs8_packing_params*`.
342 XNN_INTERNAL void xnn_pack_qs8_conv_kgo_w(
343   size_t g,
344   size_t nc,
345   size_t ks,
346   size_t nr,
347   size_t kr,
348   size_t sr,
349   const int8_t* k,
350   const int32_t* b,
351   void* packed_w,
352   size_t extra_bytes,
353   const struct xnn_qs8_packing_params* params);
354 
355 
// Type-erased function-pointer type mirroring the xnn_pack_*_deconv_goki_w routines
// declared in this header: fifteen parameters, with `k`, `b`, `packed_w` and `params`
// as void pointers. It has no `extra_bytes` argument; instead it takes a non-const
// `struct subconvolution_params*`, a type that is not defined in this header.
// NOTE(review): kh/kw look like kernel height/width and sh/sw like strides;
// `subconv_params` is presumably filled in by the routine — confirm.
356 typedef void (*xnn_pack_deconv_goki_w_function)(
357   size_t g,
358   size_t nc,
359   size_t kh,
360   size_t kw,
361   size_t kc,
362   size_t sh,
363   size_t sw,
364   size_t nr,
365   size_t kr,
366   size_t sr,
367   const void* k,
368   const void* b,
369   void* packed_w,
370   struct subconvolution_params* subconv_params,
371   const void* params);
372 
// f32 member of the deconv_goki packing family (per its name).
// `k` and `b` are const float pointers; `packed_w` is a non-const float pointer;
// `subconv_params` is a non-const pointer to struct subconvolution_params.
373 XNN_INTERNAL void xnn_pack_f32_deconv_goki_w(
374   size_t g,
375   size_t nc,
376   size_t kh,
377   size_t kw,
378   size_t kc,
379   size_t sh,
380   size_t sw,
381   size_t nr,
382   size_t kr,
383   size_t sr,
384   const float* k,
385   const float* b,
386   float* packed_w,
387   struct subconvolution_params* subconv_params,
388   const void* params);
389 
// f16 member of the deconv_goki packing family (per its name).
// `k`, `b` and `packed_w` are uint16_t pointers (the first two const).
// NOTE(review): uint16_t presumably carries half-precision bit patterns — confirm.
390 XNN_INTERNAL void xnn_pack_f16_deconv_goki_w(
391   size_t g,
392   size_t nc,
393   size_t kh,
394   size_t kw,
395   size_t kc,
396   size_t sh,
397   size_t sw,
398   size_t nr,
399   size_t kr,
400   size_t sr,
401   const uint16_t* k,
402   const uint16_t* b,
403   uint16_t* packed_w,
404   struct subconvolution_params* subconv_params,
405   const void* params);
406 
// Mixed-type member of the deconv_goki packing family: const float `k` and `b`,
// uint16_t `packed_w`.
// NOTE(review): the name suggests float→half conversion during packing — confirm.
407 XNN_INTERNAL void xnn_pack_f32_to_f16_deconv_goki_w(
408   size_t g,
409   size_t nc,
410   size_t kh,
411   size_t kw,
412   size_t kc,
413   size_t sh,
414   size_t sw,
415   size_t nr,
416   size_t kr,
417   size_t sr,
418   const float* k,
419   const float* b,
420   uint16_t* packed_w,
421   struct subconvolution_params* subconv_params,
422   const void* params);
423 
// QS8 member of the deconv_goki packing family (per its name).
// `k` is const int8_t*, `b` is const int32_t*, `packed_w` is void*;
// `params` is typed as `const struct xnn_qs8_packing_params*`.
424 XNN_INTERNAL void xnn_pack_qs8_deconv_goki_w(
425   size_t g,
426   size_t nc,
427   size_t kh,
428   size_t kw,
429   size_t kc,
430   size_t sh,
431   size_t sw,
432   size_t nr,
433   size_t kr,
434   size_t sr,
435   const int8_t* k,
436   const int32_t* b,
437   void* packed_w,
438   struct subconvolution_params* subconv_params,
439   const struct xnn_qs8_packing_params* params);
440 
// QU8 member of the deconv_goki packing family (per its name).
// `k` is const uint8_t*, `b` is const int32_t*, `packed_w` is void*;
// `params` is typed as `const struct xnn_qu8_packing_params*`.
441 XNN_INTERNAL void xnn_pack_qu8_deconv_goki_w(
442   size_t g,
443   size_t nc,
444   size_t kh,
445   size_t kw,
446   size_t kc,
447   size_t sh,
448   size_t sw,
449   size_t nr,
450   size_t kr,
451   size_t sr,
452   const uint8_t* k,
453   const int32_t* b,
454   void* packed_w,
455   struct subconvolution_params* subconv_params,
456   const struct xnn_qu8_packing_params* params);
457 
458 
459 // Pack weights and bias such that:
460 // 1. Each block contains `cr` biases and `cr * h * w` weights.
461 // 2. Within each "cr block", `cr` biases are at the beginning of the block.
462 // 3. Weights are written such that the channel values at the same x-y position are adjacent in memory.
463 // 4. The weights are then written column major (WHC layout).
464 // "ghw" in the function name is the layout of the weights, (g)roups, (h)eight, (w)idth.
// Type-erased function-pointer type mirroring the xnn_pack_*_dwconv_ghw_w routines
// declared in this header: ten parameters, with all data pointers as void pointers.
// The packed layout and the meaning of "ghw" are described in the comment
// immediately preceding this typedef.
// NOTE(review): `primary_tile` is not explained in this header — confirm its
// meaning against the implementation.
465 typedef void (*xnn_pack_dwconv_ghw_w_function)(
466   size_t primary_tile,
467   size_t h,
468   size_t w,
469   size_t c,
470   size_t cr,
471   const void* k,
472   const void* b,
473   void* packed_w,
474   size_t extra_bytes,
475   const void* params);
476 
// f32 member of the dwconv_ghw packing family (per its name).
// `k` and `b` are const float pointers; `packed_w` is a non-const float pointer.
477 XNN_INTERNAL void xnn_pack_f32_dwconv_ghw_w(
478   size_t primary_tile,
479   size_t h,
480   size_t w,
481   size_t c,
482   size_t cr,
483   const float* k,
484   const float* b,
485   float* packed_w,
486   size_t extra_bytes,
487   const void* params);
488 
// f16 member of the dwconv_ghw packing family (per its name).
// `k`, `b` and `packed_w` are uint16_t pointers (the first two const).
// NOTE(review): uint16_t presumably carries half-precision bit patterns — confirm.
489 XNN_INTERNAL void xnn_pack_f16_dwconv_ghw_w(
490   size_t primary_tile,
491   size_t h,
492   size_t w,
493   size_t c,
494   size_t cr,
495   const uint16_t* k,
496   const uint16_t* b,
497   uint16_t* packed_w,
498   size_t extra_bytes,
499   const void* params);
500 
// Mixed-type member of the dwconv_ghw packing family: const float `k` and `b`,
// uint16_t `packed_w`.
// NOTE(review): the name suggests float→half conversion during packing — confirm.
501 XNN_INTERNAL void xnn_pack_f32_to_f16_dwconv_ghw_w(
502   size_t primary_tile,
503   size_t h,
504   size_t w,
505   size_t c,
506   size_t cr,
507   const float* k,
508   const float* b,
509   uint16_t* packed_w,
510   size_t extra_bytes,
511   const void* params);
512 
// QU8 member of the dwconv_ghw packing family (per its name).
// `k` is const uint8_t*, `b` is const int32_t*, `packed_w` is void*;
// `params` is typed as `const struct xnn_qu8_packing_params*`.
513 XNN_INTERNAL void xnn_pack_qu8_dwconv_ghw_w(
514   size_t primary_tile,
515   size_t h,
516   size_t w,
517   size_t c,
518   size_t cr,
519   const uint8_t* k,
520   const int32_t* b,
521   void* packed_w,
522   size_t extra_bytes,
523   const struct xnn_qu8_packing_params* params);
524 
// QS8 member of the dwconv_ghw packing family (per its name).
// `k` is const int8_t*, `b` is const int32_t*, `packed_w` is void*;
// `params` is typed as `const struct xnn_qs8_packing_params*`.
525 XNN_INTERNAL void xnn_pack_qs8_dwconv_ghw_w(
526   size_t primary_tile,
527   size_t h,
528   size_t w,
529   size_t c,
530   size_t cr,
531   const int8_t* k,
532   const int32_t* b,
533   void* packed_w,
534   size_t extra_bytes,
535   const struct xnn_qs8_packing_params* params);
536 
537 
// Type-erased function-pointer type mirroring the xnn_pack_*_dwconv_hwg_w routines
// declared in this header. Its parameter list is identical to
// xnn_pack_dwconv_ghw_w_function (same ten parameters, types and order).
// NOTE(review): by analogy with "ghw", "hwg" presumably means the source weights
// are laid out (h)eight, (w)idth, (g)roups — confirm against the implementation.
538 typedef void (*xnn_pack_dwconv_hwg_w_function)(
539   size_t primary_tile,
540   size_t h,
541   size_t w,
542   size_t c,
543   size_t cr,
544   const void* k,
545   const void* b,
546   void* packed_w,
547   size_t extra_bytes,
548   const void* params);
549 
// f32 member of the dwconv_hwg packing family (per its name).
// `k` and `b` are const float pointers; `packed_w` is a non-const float pointer.
550 XNN_INTERNAL void xnn_pack_f32_dwconv_hwg_w(
551   size_t primary_tile,
552   size_t h,
553   size_t w,
554   size_t c,
555   size_t cr,
556   const float* k,
557   const float* b,
558   float* packed_w,
559   size_t extra_bytes,
560   const void* params);
561 
// f16 member of the dwconv_hwg packing family (per its name).
// `k`, `b` and `packed_w` are uint16_t pointers (the first two const).
// NOTE(review): uint16_t presumably carries half-precision bit patterns — confirm.
562 XNN_INTERNAL void xnn_pack_f16_dwconv_hwg_w(
563   size_t primary_tile,
564   size_t h,
565   size_t w,
566   size_t c,
567   size_t cr,
568   const uint16_t* k,
569   const uint16_t* b,
570   uint16_t* packed_w,
571   size_t extra_bytes,
572   const void* params);
573 
// Mixed-type member of the dwconv_hwg packing family: const float `k` and `b`,
// uint16_t `packed_w`.
// NOTE(review): the name suggests float→half conversion during packing — confirm.
574 XNN_INTERNAL void xnn_pack_f32_to_f16_dwconv_hwg_w(
575   size_t primary_tile,
576   size_t h,
577   size_t w,
578   size_t c,
579   size_t cr,
580   const float* k,
581   const float* b,
582   uint16_t* packed_w,
583   size_t extra_bytes,
584   const void* params);
585 
// QU8 member of the dwconv_hwg packing family (per its name).
// `k` is const uint8_t*, `b` is const int32_t*, `packed_w` is void*;
// `params` is typed as `const struct xnn_qu8_packing_params*`.
586 XNN_INTERNAL void xnn_pack_qu8_dwconv_hwg_w(
587   size_t primary_tile,
588   size_t h,
589   size_t w,
590   size_t c,
591   size_t cr,
592   const uint8_t* k,
593   const int32_t* b,
594   void* packed_w,
595   size_t extra_bytes,
596   const struct xnn_qu8_packing_params* params);
597 
// QS8 member of the dwconv_hwg packing family (per its name).
// `k` is const int8_t*, `b` is const int32_t*, `packed_w` is void*;
// `params` is typed as `const struct xnn_qs8_packing_params*`.
598 XNN_INTERNAL void xnn_pack_qs8_dwconv_hwg_w(
599   size_t primary_tile,
600   size_t h,
601   size_t w,
602   size_t c,
603   size_t cr,
604   const int8_t* k,
605   const int32_t* b,
606   void* packed_w,
607   size_t extra_bytes,
608   const struct xnn_qs8_packing_params* params);
609 
610 
// f32 "gemminc" packing routine (per its name). Compared with
// xnn_pack_f32_gemm_goi_w it has no `b` pointer and no `extra_bytes` argument;
// no function-pointer typedef for this signature is declared in this header.
// NOTE(review): the missing `b` suggests no bias is packed — confirm.
611 XNN_INTERNAL void xnn_pack_f32_gemminc_goi_w(
612   size_t g,
613   size_t nc,
614   size_t kc,
615   size_t nr,
616   size_t kr,
617   size_t sr,
618   const float* k,
619   float* packed_w,
620   const void* params);
621 
// f16 "gemminc" packing routine (per its name): same parameter list as
// xnn_pack_f32_gemminc_goi_w with uint16_t in place of float.
// NOTE(review): uint16_t presumably carries half-precision bit patterns — confirm.
622 XNN_INTERNAL void xnn_pack_f16_gemminc_goi_w(
623   size_t g,
624   size_t nc,
625   size_t kc,
626   size_t nr,
627   size_t kr,
628   size_t sr,
629   const uint16_t* k,
630   uint16_t* packed_w,
631   const void* params);
632 
633 
// f32 "dconv_oki" packing routine (per its name). Takes `nc`, `kc`, `nr`, `kh`, `kw`
// plus const float `k`/`b` and a non-const float `packed_w`; there are no `g`, `kr`,
// `sr` or `extra_bytes` arguments.
// NOTE(review): kh/kw look like kernel height/width and "oki" like the source
// weight layout — confirm against the implementation.
634 XNN_INTERNAL void xnn_pack_f32_dconv_oki_w(
635   size_t nc,
636   size_t kc,
637   size_t nr,
638   size_t kh,
639   size_t kw,
640   const float* k,
641   const float* b,
642   float* packed_w,
643   const void* params);
644 
// f16 "dconv_oki" packing routine (per its name): same parameter list as
// xnn_pack_f32_dconv_oki_w with uint16_t in place of float.
// NOTE(review): uint16_t presumably carries half-precision bit patterns — confirm.
645 XNN_INTERNAL void xnn_pack_f16_dconv_oki_w(
646   size_t nc,
647   size_t kc,
648   size_t nr,
649   size_t kh,
650   size_t kw,
651   const uint16_t* k,
652   const uint16_t* b,
653   uint16_t* packed_w,
654   const void* params);
655 
656 
// f32 "chw_dwconv_ghw" packing routine (per its name). Unlike the non-chw dwconv
// routines it takes only `kernel_size` and `groups` as sizes, and names its
// pointers `kernel`, `bias` (both const float*) and `packed_weights` (float*).
// NOTE(review): how the packed layout differs from xnn_pack_f32_dwconv_ghw_w is
// not visible from this declaration — consult the implementation.
657 XNN_INTERNAL void xnn_pack_f32_chw_dwconv_ghw_w(
658   size_t kernel_size,
659   size_t groups,
660   const float* kernel,
661   const float* bias,
662   float* packed_weights,
663   const void* params);
664 
// f16 "chw_dwconv_ghw" packing routine (per its name): same parameter list as
// xnn_pack_f32_chw_dwconv_ghw_w with uint16_t in place of float.
// NOTE(review): uint16_t presumably carries half-precision bit patterns — confirm.
665 XNN_INTERNAL void xnn_pack_f16_chw_dwconv_ghw_w(
666   size_t kernel_size,
667   size_t groups,
668   const uint16_t* kernel,
669   const uint16_t* bias,
670   uint16_t* packed_weights,
671   const void* params);
672 
673 
// f32 "chw_dwconv_hwg" packing routine (per its name). Its parameter list is
// identical to xnn_pack_f32_chw_dwconv_ghw_w; only an f32 variant of the hwg
// form is declared in this header.
// NOTE(review): presumably differs from the ghw form only in the source kernel
// layout — confirm against the implementation.
674 XNN_INTERNAL void xnn_pack_f32_chw_dwconv_hwg_w(
675   size_t kernel_size,
676   size_t groups,
677   const float* kernel,
678   const float* bias,
679   float* packed_weights,
680   const void* params);
681 
682 
// Type-erased function-pointer type mirroring the xnn_pack_*_vmulcaddc_w routines
// declared in this header: six parameters, with `s`, `b` and `packed_w` as void pointers.
// NOTE(review): `s` and `b` are presumably per-channel scale and bias, `c` the
// channel count and `cr` the channel tile — confirm against the implementation.
683 typedef void (*xnn_pack_vmulcaddc_w_function)(
684   size_t c,
685   size_t cr,
686   const void* s,
687   const void* b,
688   void* packed_w,
689   const void* params);
690 
// f32 member of the vmulcaddc packing family (per its name).
// `s` and `b` are const float pointers; `packed_w` is a non-const float pointer.
691 XNN_INTERNAL void xnn_pack_f32_vmulcaddc_w(
692   size_t c,
693   size_t cr,
694   const float* s,
695   const float* b,
696   float* packed_w,
697   const void* params);
698 
// f16 member of the vmulcaddc packing family (per its name).
// `s`, `b` and `packed_w` are uint16_t pointers (the first two const).
// NOTE(review): uint16_t presumably carries half-precision bit patterns — confirm.
699 XNN_INTERNAL void xnn_pack_f16_vmulcaddc_w(
700   size_t c,
701   size_t cr,
702   const uint16_t* s,
703   const uint16_t* b,
704   uint16_t* packed_w,
705   const void* params);
706 
// Mixed-type member of the vmulcaddc packing family: const float `s` and `b`,
// uint16_t `packed_w`.
// NOTE(review): the name suggests float→half conversion during packing — confirm.
707 XNN_INTERNAL void xnn_pack_f32_to_f16_vmulcaddc_w(
708   size_t c,
709   size_t cr,
710   const float* s,
711   const float* b,
712   uint16_t* packed_w,
713   const void* params);
714 
715 
// Type-erased function-pointer type mirroring the xnn_pack_*_prelu_w routines
// declared in this header: a count `c`, a const source pointer `s` and a non-const
// `packed_w` pointer. It is the only packing signature here without a `params` argument.
// NOTE(review): `s` presumably holds `c` per-channel PReLU slopes — confirm.
716 typedef void (*xnn_pack_prelu_w_function)(
717   size_t c,
718   const void* s,
719   void* packed_w);
720 
// f32 member of the prelu packing family (per its name).
// `s` is a const float pointer; `packed_w` is a non-const float pointer.
721 XNN_INTERNAL void xnn_pack_f32_prelu_w(
722   size_t c,
723   const float* s,
724   float* packed_w);
725 
// f16 member of the prelu packing family (per its name).
// `s` is a const uint16_t pointer; `packed_w` is a non-const uint16_t pointer.
// NOTE(review): uint16_t presumably carries half-precision bit patterns — confirm.
726 XNN_INTERNAL void xnn_pack_f16_prelu_w(
727   size_t c,
728   const uint16_t* s,
729   uint16_t* packed_w);
730 
// Mixed-type member of the prelu packing family: const float `s`, uint16_t `packed_w`.
// NOTE(review): the name suggests float→half conversion during packing — confirm.
731 XNN_INTERNAL void xnn_pack_f32_to_f16_prelu_w(
732   size_t c,
733   const float* s,
734   uint16_t* packed_w);
735 
736 
737 #ifdef __cplusplus
738 }  // extern "C"
739 #endif
740