xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/operator-run.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <assert.h>
10 #include <stddef.h>
11 #include <stdint.h>
12 #include <string.h>
13 
14 #include <pytorch_qnnpack.h>
15 #include <qnnpack/common.h>
16 #include <qnnpack/log.h>
17 #include <qnnpack/math.h>
18 #include <qnnpack/operator.h>
19 #include <qnnpack/params.h>
20 
21 #ifdef _MSC_VER
22 #include <malloc.h>
23 #endif
24 
25 struct q8gemm_context {
26   size_t k;
27   size_t k_stride;
28   size_t n;
29   size_t n_stride;
30   const uint8_t* a;
31   size_t a_stride;
32   const uint8_t* packed_w;
33   uint8_t* c;
34   size_t c_stride;
35   union pytorch_qnnp_conv_quantization_params quantization_params;
36   const pytorch_q8gemm_ukernel_function ukernel;
37 };
38 
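// Computes one mr x nr output tile of a quantized GEMM for a single group.
// As reflected in the offset arithmetic below, the packed weight buffer holds,
// per output channel, k_stride uint8 weights plus one int32 bias, so weight
// offsets scale by (k_stride * sizeof(uint8_t) + sizeof(int32_t)).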
39 static void compute_q8gemm(
40     const struct q8gemm_context context[RESTRICT_STATIC 1],
41     size_t group_index,
42     size_t pixel_index,
43     size_t mr_block_start,
44     size_t nr_block_start,
45     size_t group_range /* always 1 */,
46     size_t pixel_range,
47     size_t mr_block_size,
48     size_t nr_block_size) {
49   const size_t k = context->k;
50   const size_t k_stride = context->k_stride;
51   const size_t n = context->n;
52   const size_t n_stride = context->n_stride;
53   const uint8_t* restrict a = context->a;
54   const size_t a_stride = context->a_stride;
55   const void* restrict packed_w = context->packed_w;
56   uint8_t* restrict c = context->c;
57   const size_t c_stride = context->c_stride;
58 
59   size_t output_channel_index = nr_block_start + group_index * n;
60   context->ukernel(
61       mr_block_size,
62       nr_block_size,
63       k,
64       a + (pixel_index + mr_block_start) * a_stride + group_index * k,
65       a_stride,
66       (const void*)((uintptr_t)packed_w + (nr_block_start + group_index * n_stride) * (k_stride * sizeof(uint8_t) + sizeof(int32_t))),
67       c + (pixel_index + mr_block_start) * c_stride + nr_block_start +
68           group_index * n,
69       c_stride,
70       output_channel_index,
71       &context->quantization_params);
72 }
73 
74 // At the moment we opt to remove the sparse kernels that
75 // don't require prepacking, as their performance was always
76 // worse.
77 #ifdef NO_PREPACK_SPARSE_KERNEL
78 struct q8gemm_sparse_dq_context {
79   const uint8_t* a;
80   size_t a_stride;
81   const uint32_t* kernel_col_indices;
82   const uint32_t* kernel_row_values;
83   const uint8_t* kernel_values;
84   const float* bias;
85   float* c;  // can be float or uint8_t
86   size_t c_stride;
87   struct pytorch_qnnp_conv_dynamic_quantization_params quantization_params;
88   const pytorch_q8gemm_dq_sparse_ukernel_function ukernel;
89 };
90 
91 static void compute_q8gemm_sparse_dq(
92     const struct q8gemm_sparse_dq_context context[RESTRICT_STATIC 1],
93     size_t group_index, /* ignored */
94     size_t pixel_index, /* ignored */
95     size_t mr_block_start,
96     size_t nr_block_start,
97     size_t group_range /* always 1 */,
98     size_t pixel_range,
99     size_t mr_block_size,
100     size_t nr_block_size) {
101   const uint8_t* restrict a = context->a;
102   const size_t a_stride = context->a_stride;
103   float* restrict c = (float*)context->c;
104   const size_t c_stride = context->c_stride;
105 
106   size_t output_channel_index = nr_block_start;
107   context->ukernel(
108       mr_block_size,
109       nr_block_size,
110       a + mr_block_start * a_stride,
111       a_stride,
112       context->kernel_values,
113       context->kernel_row_values + nr_block_start,
114       context->kernel_col_indices,
115       context->bias + nr_block_start,
116       c + mr_block_start * c_stride + nr_block_start,
117       c_stride,
118       output_channel_index,
119       &context->quantization_params);
120 }
121 #endif
122 
123 struct q8gemm_prepackA_sparse_dq_context {
124   size_t k;
125   const uint8_t* a;
126   size_t a_stride;
127   uint8_t* a_packed;
128   size_t a_packed_stride;
129   size_t log2_mr;
130   size_t log2_row_block_size;
131   union {
132     const uint32_t* kernel_col_indices_w32;
133     const uint16_t* kernel_col_indices_w16;
134     const uint8_t* kernel_col_indices_w8;
135   };
136   union {
137     const uint32_t* kernel_row_values_w32;
138     const uint16_t* kernel_row_values_w16;
139     const uint8_t* kernel_row_values_w8;
140   };
141   enum pytorch_qnnp_sparse_matrix_indices_dtype kernel_indices_dtype;
142   const uint8_t* kernel_values;
143   const float* bias;
144   float* c;  // can be float or uint8_t
145   size_t c_stride;
146   struct pytorch_qnnp_conv_dynamic_quantization_params quantization_params;
147   union {
148     // Not const because assigned after context is initialized
149     pytorch_q8gemm_dq_sparse_packedA_w32_ukernel_function ukernel_w32;
150     pytorch_q8gemm_dq_sparse_packedA_w16_ukernel_function ukernel_w16;
151     pytorch_q8gemm_dq_sparse_packedA_w8_ukernel_function ukernel_w8;
152   };
153   const pytorch_q8gemm_sparse_packA_ukernel_function prepack_ukernel;
154 };
155 
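// Packs an mr-row block of the activation matrix A into the contiguous layout
// expected by the packed-A sparse GEMM ukernels. Packed blocks are
// a_packed_stride bytes each and are addressed by mr_block_start >> log2_mr.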
156 static void compute_q8gemm_prepack_a_sparse(
157     const struct q8gemm_prepackA_sparse_dq_context context[RESTRICT_STATIC 1],
158     size_t group_index, /* ignored */
159     size_t pixel_index, /* ignored */
160     size_t mr_block_start,
161     size_t nr_block_start,
162     size_t group_range /* always 1 */,
163     size_t pixel_range,
164     size_t mr_block_size,
165     size_t nr_block_size) {
166   const uint8_t* restrict a = context->a;
167   const size_t a_stride = context->a_stride;
168   const size_t mr_packed_block_start =
169     ((mr_block_start >> context->log2_mr) * context->a_packed_stride);
170 
171   context->prepack_ukernel(
172       mr_block_size,
173       context->k,
174       a + mr_block_start * a_stride,
175       a_stride,
176       context->a_packed + mr_packed_block_start);
177 }
178 
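// Runs the dynamically quantized sparse GEMM on a previously packed A block.
// The row-value/column-index arrays and the ukernel are selected according to
// the width (8-, 16-, or 32-bit) of the sparse matrix indices.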
179 static void compute_q8gemm_prepacked_sparse_dq(
180     const struct q8gemm_prepackA_sparse_dq_context context[RESTRICT_STATIC 1],
181     size_t group_index, /* ignored */
182     size_t pixel_index, /* ignored */
183     size_t mr_block_start,
184     size_t nr_block_start,
185     size_t group_range /* always 1 */,
186     size_t pixel_range,
187     size_t mr_block_size,
188     size_t nr_block_size) {
189   const size_t mr_packed_block_start =
190     ((mr_block_start >> context->log2_mr) * context->a_packed_stride);
191   const uint8_t* restrict a_packed = context->a_packed + mr_packed_block_start;
192   const size_t c_stride = context->c_stride;
193   float* restrict c =
194       ((float*)context->c) + mr_block_start * c_stride + nr_block_start;
195   const size_t kernel_row_values_shift =
196       nr_block_start >> context->log2_row_block_size;
197   const float* bias = context->bias + nr_block_start;
198   const size_t output_channel_index = nr_block_start;
199 
200   switch (context->kernel_indices_dtype) {
201     case pytorch_qnnp_sparse_matrix_indices_dtype_uint32_t:
202       context->ukernel_w32(
203           mr_block_size,
204           nr_block_size,
205           a_packed,
206           context->kernel_values,
207           context->kernel_row_values_w32 + kernel_row_values_shift,
208           context->kernel_col_indices_w32,
209           bias,
210           c,
211           c_stride,
212           output_channel_index,
213           &context->quantization_params);
214       break;
215     case pytorch_qnnp_sparse_matrix_indices_dtype_uint16_t:
216       context->ukernel_w16(
217           mr_block_size,
218           nr_block_size,
219           a_packed,
220           context->kernel_values,
221           context->kernel_row_values_w16 + kernel_row_values_shift,
222           context->kernel_col_indices_w16,
223           bias,
224           c,
225           c_stride,
226           output_channel_index,
227           &context->quantization_params);
228       break;
229     case pytorch_qnnp_sparse_matrix_indices_dtype_uint8_t:
230       context->ukernel_w8(
231           mr_block_size,
232           nr_block_size,
233           a_packed,
234           context->kernel_values,
235           context->kernel_row_values_w8 + kernel_row_values_shift,
236           context->kernel_col_indices_w8,
237           bias,
238           c,
239           c_stride,
240           output_channel_index,
241           &context->quantization_params);
242       break;
243     case pytorch_qnnp_sparse_matrix_indices_dtype_invalid:
244       // This function cannot return an error code without substantially
245       // changing the internal API. A check for invalid index type should
246       // already exist in the calling function. If the code reaches here, then
247       // please add / restore the index check in the calling function.
248       pytorch_qnnp_log_error(
249           "Invalid indices dtype specified for "
250           "operator-run compute_q8gemm_prepacked_sparse_dq");
251       assert(false);
252   }
253 }
254 
255 struct q8sum_rows_context {
256   const uint8_t* a;
257   size_t groups;
258   size_t m;
259   size_t k;
260   size_t a_stride;
261   const int32_t multiplier;
262   int32_t* a_sum;
263   size_t a_sum_stride;
264   const pytorch_q8sum_rows_ukernel_function ukernel;
265 };
266 
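// Computes per-row sums of the input scaled by the context multiplier
// (set to -kernel_zero_point when used for the XZP GEMM below), so that the
// kernel zero-point contribution can be folded into the matrix multiplication.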
267 static void compute_sum_rows(
268     const struct q8sum_rows_context context[RESTRICT_STATIC 1],
269     size_t group_index,
270     size_t batch_index,
271     size_t block_start,
272     size_t group_range /* always 1 */,
273     size_t batch_range /* always 1 */,
274     size_t block_size) {
275   const uint8_t* a = context->a;
276   const size_t groups = context->groups;
277   const size_t m = context->m;
278   const size_t k = context->k;
279   const size_t a_stride = context->a_stride;
280   const int32_t multiplier = context->multiplier;
281   int32_t* a_sum = context->a_sum;
282   const size_t a_sum_stride = context->a_sum_stride;
283 
284   context->ukernel(
285       a + batch_index * m * a_stride + group_index * k + block_start * a_stride,
286       min(block_size, m - block_start),
287       k,
288       a_stride,
289       multiplier,
290       a_sum + batch_index * groups * a_sum_stride + group_index * a_sum_stride +
291           block_start);
292 }
293 
294 struct q8gemm_xzp_context {
295   size_t k;
296   size_t k_stride;
297   size_t n;
298   size_t n_stride;
299   const uint8_t* a;
300   size_t a_stride;
301   const void* packed_w;
302   uint8_t* c;
303   size_t c_stride;
304   const int32_t* a_sum;
305   size_t groups;
306   size_t batch_size;
307   size_t a_sum_stride;
308   union pytorch_qnnp_q31_requantization_params requantization_params;
309   const pytorch_q8gemm_xzp_ukernel_function ukernel;
310 };
311 
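// GEMM variant that consumes the precomputed row sums (a_sum) to account for
// the kernel zero point, followed by Q31 requantization of the accumulators.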
312 static void compute_q8gemm_xzp(
313     const struct q8gemm_xzp_context context[RESTRICT_STATIC 1],
314     size_t group_index,
315     size_t pixel_index,
316     size_t mr_block_start,
317     size_t nr_block_start,
318     size_t group_range /* always 1 */,
319     size_t pixel_range,
320     size_t mr_block_size,
321     size_t nr_block_size) {
322   const size_t k = context->k;
323   const size_t k_stride = context->k_stride;
324   const size_t n = context->n;
325   const size_t n_stride = context->n_stride;
326   const uint8_t* restrict a = context->a;
327   const size_t a_stride = context->a_stride;
328   const void* restrict packed_w = context->packed_w;
329   uint8_t* restrict c = context->c;
330   const size_t c_stride = context->c_stride;
331   const int32_t* a_sum = context->a_sum;
332   const size_t groups = context->groups;
333   const size_t a_sum_stride = context->a_sum_stride;
334 
335   context->ukernel(
336       mr_block_size,
337       nr_block_size,
338       k,
339       a + (pixel_index + mr_block_start) * a_stride + group_index * k,
340       a_stride,
341       a_sum + pixel_index * groups + group_index * a_sum_stride +
342           mr_block_start,
343       (const void*)((uintptr_t)packed_w + (nr_block_start + group_index * n_stride) * (k_stride * sizeof(uint8_t) + sizeof(int32_t))),
344       c + (pixel_index + mr_block_start) * c_stride + nr_block_start +
345           group_index * n,
346       c_stride,
347       &context->requantization_params);
348 }
349 
350 struct q8conv_context {
351   size_t bs;
352   size_t ks;
353   size_t kc;
354   size_t kc_stride;
355   size_t m;
356   size_t m_stride;
357   size_t n;
358   size_t n_stride;
359   const uint8_t** indirect_a;
360   const void* packed_w;
361   uint8_t* c;
362   size_t c_stride;
363   union pytorch_qnnp_conv_quantization_params quantization_params;
364   const pytorch_q8conv_ukernel_function ukernel;
365 };
366 
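// Convolution tile: indirect_a is an indirection buffer holding ks pointers
// per output pixel, so the ukernel gathers input rows through pointers rather
// than through an explicit im2col copy.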
367 static void compute_q8conv(
368     const struct q8conv_context context[RESTRICT_STATIC 1],
369     size_t group_index,
370     size_t image_index,
371     size_t mr_block_start,
372     size_t nr_block_start,
373     size_t group_range /* always 1 */,
374     size_t image_range /* always 1 */,
375     size_t mr_block_size,
376     size_t nr_block_size) {
377   const size_t bs = context->bs;
378   const size_t ks = context->ks;
379   const size_t kc = context->kc;
380   const size_t kc_stride = context->kc_stride;
381   const size_t m = context->m;
382   const size_t m_stride = context->m_stride;
383   const size_t n = context->n;
384   const size_t n_stride = context->n_stride;
385   const uint8_t** restrict indirect_a = context->indirect_a;
386   const void* restrict packed_w = context->packed_w;
387   uint8_t* restrict c = context->c;
388   const size_t c_stride = context->c_stride;
389 
390   size_t output_channel_index = nr_block_start + group_index * n;
391   context->ukernel(
392       mr_block_size,
393       nr_block_size,
394       kc,
395       ks,
396       indirect_a +
397           (mr_block_start + (image_index + group_index * bs) * m_stride) * ks,
398       (const void*)((uintptr_t)packed_w + (nr_block_start + group_index * n_stride) * (kc_stride * sizeof(uint8_t) + sizeof(int32_t))),
399       c + (mr_block_start + image_index * m) * c_stride + group_index * n +
400           nr_block_start,
401       c_stride,
402       output_channel_index,
403       &context->quantization_params);
404 }
405 
406 struct q8dwconv2d_context {
407   size_t groups;
408   size_t group_stride;
409   const uint8_t** indirection_buffer;
410   size_t indirection_buffer_row_stride;
411   size_t indirection_buffer_col_stride;
412   const void* packed_weights;
413   uint8_t* output;
414   size_t output_height;
415   size_t output_width;
416   size_t output_row_stride;
417   size_t output_col_increment;
418   union pytorch_qnnp_conv_quantization_params quantization_params;
419   union {
420     const pytorch_q8dwconv2d_up_ukernel_function unipass_ukernel;
421     const pytorch_q8dwconv2d_mp_ukernel_function multipass_ukernel;
422   };
423 };
424 
425 struct q8dwconv3d_context {
426   size_t groups;
427   size_t group_stride;
428   const uint8_t** indirection_buffer;
429   size_t indirection_buffer_slice_stride;
430   size_t indirection_buffer_row_stride;
431   size_t indirection_buffer_col_stride;
432   const void* packed_weights;
433   uint8_t* output;
434   size_t output_depth;
435   size_t output_height;
436   size_t output_width;
437   size_t output_slice_stride;
438   union pytorch_qnnp_conv_quantization_params quantization_params;
439   const pytorch_q8dwconv3d_mp_ukernel_function multipass_ukernel;
440 };
441 
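// Single-pass 2D depthwise convolution: the ukernel consumes the whole kernel
// window (e.g. 3x3) in one pass over an indirection-buffer row.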
442 static void compute_dwconv2d_unipass(
443     const struct q8dwconv2d_context context[RESTRICT_STATIC 1],
444     size_t image,
445     size_t output_y) {
446   const size_t output_height = context->output_height;
447 
448   context->unipass_ukernel(
449       context->groups,
450       context->output_width,
451       context->indirection_buffer +
452           (image * output_height + output_y) *
453               context->indirection_buffer_row_stride,
454       context->packed_weights,
455       context->output +
456           (image * output_height + output_y) * context->output_row_stride,
457       context->indirection_buffer_col_stride,
458       context->output_col_increment,
459       &context->quantization_params);
460 }
461 
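// Multi-pass 2D depthwise convolution for larger kernel windows (e.g. 5x5).
// A per-call int32 scratch accumulator of group_stride elements is allocated
// on the stack (a VLA, or _malloca on MSVC) and handed to the ukernel.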
462 static void compute_dwconv2d_multiipass(
463     const struct q8dwconv2d_context context[RESTRICT_STATIC 1],
464     size_t image,
465     size_t output_y) {
466   const size_t output_height = context->output_height;
467   PYTORCH_QNNP_ALIGN(16)
468 #ifdef _MSC_VER
469   int32_t* multipass_acc = _malloca(sizeof(int32_t) * context->group_stride);
470 #else
471   int32_t multipass_acc[context->group_stride];
472 #endif
473 
474   context->multipass_ukernel(
475       context->groups,
476       context->output_width,
477       context->indirection_buffer +
478           (image * output_height + output_y) *
479               context->indirection_buffer_row_stride,
480       context->packed_weights,
481       multipass_acc,
482       context->output +
483           (image * output_height + output_y) * context->output_row_stride,
484       context->indirection_buffer_col_stride,
485       context->output_col_increment,
486       &context->quantization_params);
487 
488 #ifdef _MSC_VER
489   _freea(multipass_acc);
490 #endif
491 }
492 
493 static void compute_dwconv3d_multiipass(
494     const struct q8dwconv3d_context context[1],
495     size_t image,
496     size_t output_z) {
497   const size_t output_depth = context->output_depth;
498   PYTORCH_QNNP_ALIGN(16)
499 #ifdef _MSC_VER
500   int32_t* multipass_acc =
501       (int32_t*)_malloca(sizeof(int32_t) * context->group_stride);
502 #else
503   int32_t multipass_acc[context->group_stride];
504 #endif
505 
506   context->multipass_ukernel(
507       context->groups,
508       context->output_height,
509       context->output_width,
510       context->indirection_buffer +
511           (image * output_depth + output_z) *
512               context->indirection_buffer_slice_stride,
513       context->packed_weights,
514       multipass_acc,
515       context->output +
516           (image * output_depth + output_z) * context->output_slice_stride,
517       context->indirection_buffer_row_stride,
518       context->indirection_buffer_col_stride,
519       0,
520       &context->quantization_params);
521 
522 #ifdef _MSC_VER
523   _freea(multipass_acc);
524 #endif
525 }
526 
527 struct max_pooling_context {
528   const void** indirect_input;
529   size_t indirect_input_batch_stride;
530   size_t indirect_input_height_stride;
531   void* output;
532   size_t output_batch_stride;
533   size_t output_height_stride;
534   size_t output_width;
535   size_t pooling_size;
536   size_t channels;
537   size_t input_increment;
538   size_t output_increment;
539   union pytorch_qnnp_u8_clamping_params params;
540   pytorch_u8maxpool_ukernel_function ukernel;
541 };
542 
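// Max pooling over one output row: input pixels are reached through the
// indirection buffer, offset by the batch and output-row strides.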
543 static void compute_max_pooling(
544     const struct max_pooling_context context[RESTRICT_STATIC 1],
545     size_t batch_index,
546     size_t output_y) {
547   const void** indirect_input =
548     (const void**) ((uintptr_t) context->indirect_input +
549       batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
550   void* output =
551     (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
552 
553   context->ukernel(
554       context->output_width,
555       context->pooling_size,
556       context->channels,
557       (const uint8_t**)indirect_input,
558       output,
559       context->input_increment,
560       context->output_increment,
561       &context->params);
562 }
563 
564 struct average_pooling_context {
565   const void** indirect_input;
566   size_t indirect_input_batch_stride;
567   size_t indirect_input_height_stride;
568   void* output;
569   size_t output_batch_stride;
570   size_t output_height_stride;
571   size_t output_width;
572   size_t pooling_size;
573   size_t channels;
574   size_t packed_channels;
575   const void* zero;
576   size_t input_increment;
577   size_t output_increment;
578   union pytorch_qnnp_avgpool_quantization_params quantization_params;
579   union {
580     pytorch_q8avgpool_up_ukernel_function unipass_ukernel;
581     pytorch_q8avgpool_mp_ukernel_function multipass_ukernel;
582   };
583 };
584 
585 static void compute_average_pooling_unipass(
586     const struct average_pooling_context context[RESTRICT_STATIC 1],
587     size_t batch_index,
588     size_t output_y) {
589   const void** indirect_input =
590     (const void**) ((uintptr_t) context->indirect_input +
591       batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
592   void* output =
593     (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
594 
595   context->unipass_ukernel(
596       context->output_width,
597       context->pooling_size,
598       context->channels,
599       (const uint8_t**)indirect_input,
600       context->zero,
601       output,
602       context->input_increment,
603       context->output_increment,
604       &context->quantization_params);
605 }
606 
607 static void compute_average_pooling_multipass(
608     const struct average_pooling_context context[RESTRICT_STATIC 1],
609     size_t batch_index,
610     size_t output_y) {
611   const void** indirect_input =
612     (const void**) ((uintptr_t) context->indirect_input +
613       batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
614   void* output =
615     (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
616   PYTORCH_QNNP_ALIGN(16)
617 #ifdef _MSC_VER
618   int32_t* multipass_buffer =
619       _malloca(sizeof(int32_t) * context->packed_channels);
620 #else
621   int32_t multipass_buffer[context->packed_channels];
622 #endif
623 
624   context->multipass_ukernel(
625       context->output_width,
626       context->pooling_size,
627       context->channels,
628       (const uint8_t**)indirect_input,
629       context->zero,
630       multipass_buffer,
631       output,
632       context->input_increment,
633       context->output_increment,
634       &context->quantization_params);
635 
636 #ifdef _MSC_VER
637   _freea(multipass_buffer);
638 #endif
639 }
640 
641 struct global_average_pooling_context {
642   const void* input;
643   const void* zero;
644   size_t input_pixel_stride;
645   size_t input_batch_stride;
646   size_t input_elements;
647   size_t channels;
648   size_t packed_channels;
649   void* output;
650   size_t output_batch_stride;
651   union pytorch_qnnp_avgpool_quantization_params quantization_params;
652   union {
653     pytorch_q8gavgpool_up_ukernel_function unipass_ukernel;
654     pytorch_q8gavgpool_mp_ukernel_function multipass_ukernel;
655   };
656 };
657 
658 static void compute_global_average_pooling_unipass(
659     const struct global_average_pooling_context context[RESTRICT_STATIC 1],
660     size_t batch_index) {
661   const void* input =
662       (const void*)((uintptr_t)context->input + batch_index * context->input_batch_stride);
663   void* output =
664       (void*)((uintptr_t)context->output + batch_index * context->output_batch_stride);
665 
666   context->unipass_ukernel(
667       context->input_elements,
668       context->channels,
669       input,
670       context->input_pixel_stride,
671       context->zero,
672       output,
673       &context->quantization_params);
674 }
675 
676 static void compute_global_average_pooling_multipass(
677     const struct global_average_pooling_context context[RESTRICT_STATIC 1],
678     size_t batch_index) {
679   const void* input =
680       (const void*)((uintptr_t)context->input + batch_index * context->input_batch_stride);
681   void* output =
682       (void*)((uintptr_t)context->output + batch_index * context->output_batch_stride);
683   PYTORCH_QNNP_ALIGN(16)
684 #ifdef _MSC_VER
685   int32_t* multipass_buffer =
686       _malloca(sizeof(int32_t) * context->packed_channels);
687 #else
688   int32_t multipass_buffer[context->packed_channels];
689 #endif
690 
691   context->multipass_ukernel(
692       context->input_elements,
693       context->channels,
694       input,
695       context->input_pixel_stride,
696       context->zero,
697       multipass_buffer,
698       output,
699       &context->quantization_params);
700 
701 #ifdef _MSC_VER
702   _freea(multipass_buffer);
703 #endif
704 }
705 
706 struct q8add_strided_context {
707   size_t n;
708   const uint8_t* a;
709   size_t a_stride;
710   const uint8_t* b;
711   size_t b_stride;
712   const uint8_t* y;
713   size_t y_stride;
714   union pytorch_qnnp_add_quantization_params quantization_params;
715   pytorch_q8vadd_ukernel_function ukernel;
716 };
717 
718 static void compute_q8add_strided(
719     const struct q8add_strided_context context[RESTRICT_STATIC 1],
720     size_t batch_offset,
721     size_t batch_range /* always 1 */) {
722   assert(batch_range == 1);
723 
724   const size_t n = context->n;
725   const size_t a_stride = context->a_stride;
726   const size_t b_stride = context->b_stride;
727   const size_t y_stride = context->y_stride;
728   const void* a =
729       (const void*)((uintptr_t)context->a + a_stride * batch_offset);
730   const void* b =
731       (const void*)((uintptr_t)context->b + b_stride * batch_offset);
732   void* y = (void*)((uintptr_t)context->y + y_stride * batch_offset);
733 
734   context->ukernel(n, a, b, y, &context->quantization_params);
735 }
736 
737 struct q8add_contiguous_context {
738   const uint8_t* a;
739   const uint8_t* b;
740   uint8_t* y;
741   union pytorch_qnnp_add_quantization_params quantization_params;
742   pytorch_q8vadd_ukernel_function ukernel;
743 };
744 
745 static void compute_q8add_contiguous(
746     const struct q8add_contiguous_context context[RESTRICT_STATIC 1],
747     size_t offset,
748     size_t size) {
749   const void* a = (const void*)((uintptr_t)context->a + offset);
750   const void* b = (const void*)((uintptr_t)context->b + offset);
751   void* y = (void*)((uintptr_t)context->y + offset);
752   context->ukernel(size, a, b, y, &context->quantization_params);
753 }
754 
755 struct channel_shuffle_context {
756   const void* x;
757   size_t x_stride;
758   void* y;
759   size_t y_stride;
760   size_t n;
761   size_t m;
762   union {
763     pytorch_xzipc_ukernel_function fixed_ukernel;
764     pytorch_xzipv_ukernel_function variable_ukernel;
765   };
766 };
767 
768 static void compute_channel_shuffle_fixed(
769     const struct channel_shuffle_context context[RESTRICT_STATIC 1],
770     size_t index) {
771   const void* x =
772       (const void*)((uintptr_t)context->x + index * context->x_stride);
773   void* y = (void*)((uintptr_t)context->y + index * context->y_stride);
774 
775   context->fixed_ukernel(context->n, x, y);
776 }
777 
778 static void compute_channel_shuffle_variable(
779     const struct channel_shuffle_context context[RESTRICT_STATIC 1],
780     size_t index) {
781   const void* x =
782       (const void*)((uintptr_t)context->x + index * context->x_stride);
783   void* y = (void*)((uintptr_t)context->y + index * context->y_stride);
784 
785   context->variable_ukernel(context->n, context->m, x, y);
786 }
787 
788 struct lut_strided_context {
789   size_t n;
790   const void* x;
791   size_t x_stride;
792   const void* t;
793   void* y;
794   size_t y_stride;
795   pytorch_x8lut_ukernel_function ukernel;
796 };
797 
798 static void compute_lut_strided(
799     const struct lut_strided_context context[RESTRICT_STATIC 1],
800     size_t batch_index) {
801   const void* x =
802       (const void*)((uintptr_t)context->x + context->x_stride * batch_index);
803   void* y = (void*)((uintptr_t)context->y + context->y_stride * batch_index);
804 
805   context->ukernel(context->n, x, context->t, y);
806 }
807 
808 struct lut_contiguous_context {
809   const void* x;
810   size_t x_stride;
811   const void* t;
812   void* y;
813   size_t y_stride;
814   pytorch_x8lut_ukernel_function ukernel;
815 };
816 
817 static void compute_lut_contiguous(
818     const struct lut_contiguous_context context[RESTRICT_STATIC 1],
819     size_t offset,
820     size_t size) {
821   const void* x = (const void*)((uintptr_t)context->x + offset);
822   void* y = (void*)((uintptr_t)context->y + offset);
823 
824   context->ukernel(size, x, context->t, y);
825 }
826 
827 struct clamp_strided_context {
828   size_t n;
829   const void* x;
830   size_t x_stride;
831   void* y;
832   size_t y_stride;
833   pytorch_u8clamp_ukernel_function ukernel;
834   union pytorch_qnnp_u8_clamping_params params;
835 };
836 
837 static void compute_clamp_strided(
838     const struct clamp_strided_context context[RESTRICT_STATIC 1],
839     size_t batch_index) {
840   const void* x =
841       (const void*)((uintptr_t)context->x + context->x_stride * batch_index);
842   void* y = (void*)((uintptr_t)context->y + context->y_stride * batch_index);
843   context->ukernel(context->n, x, y, &context->params);
844 }
845 
846 struct clamp_contiguous_context {
847   const void* x;
848   size_t x_stride;
849   void* y;
850   size_t y_stride;
851   pytorch_u8clamp_ukernel_function ukernel;
852   union pytorch_qnnp_u8_clamping_params params;
853 };
854 
855 static void compute_clamp_contiguous(
856     const struct clamp_contiguous_context context[RESTRICT_STATIC 1],
857     size_t offset,
858     size_t size) {
859   const void* x = (const void*)((uintptr_t)context->x + offset);
860   void* y = (void*)((uintptr_t)context->y + offset);
861   context->ukernel(size, x, y, &context->params);
862 }
863 
864 struct u8softargmax_context {
865   size_t n;
866   const uint8_t* x;
867   size_t x_stride;
868   const uint32_t* t;
869   uint8_t* y;
870   size_t y_stride;
871   pytorch_u8rmax_ukernel_function rmax_ukernel;
872   pytorch_u8lut32norm_ukernel_function lut_norm_ukernel;
873 };
874 
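// Soft(arg)max on uint8 data: rmax_ukernel finds the row maximum, and the
// lookup table pointer is shifted by (255 - x_max), i.e. x_max ^ 255, so the
// table is indexed relative to that maximum before lut_norm_ukernel applies
// it and normalizes the result.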
875 static void compute_u8softargmax(
876     const struct u8softargmax_context context[RESTRICT_STATIC 1],
877     size_t batch_index) {
878   const uint8_t* x =
879       (const uint8_t*)((uintptr_t)context->x + context->x_stride * batch_index);
880   uint8_t* y =
881       (uint8_t*)((uintptr_t)context->y + context->y_stride * batch_index);
882   const size_t n = context->n;
883 
884   const uint8_t x_max = context->rmax_ukernel(n, x);
885   const size_t adjustment = x_max ^ 255;
886   const uint32_t* t = (const uint32_t*)context->t + adjustment;
887   context->lut_norm_ukernel(n, x, t, y);
888 }
889 
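// Dispatches the operator: selects the context struct and compute_* callback
// for op->ukernel_type and partitions the work across the thread pool with
// pthreadpool_compute_{2d,3d,4d}[_tiled] parallel loops.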
890 enum pytorch_qnnp_status pytorch_qnnp_run_operator(
891     pytorch_qnnp_operator_t op,
892     pthreadpool_t threadpool) {
893   // For any ukernel type, there is no work to do if the batch size is 0.
894   if (op->batch_size == 0) {
895     return pytorch_qnnp_status_success;
896   }
897 
898   switch (op->ukernel_type) {
899     case pytorch_qnnp_ukernel_type_dwconv: {
900       const size_t batch_size = op->batch_size;
901       const size_t groups = op->groups;
902       const size_t kernel_depth = op->kernel_depth;
903       const size_t kernel_height = op->kernel_height;
904       const size_t kernel_width = op->kernel_width;
905       const size_t kernel_size = kernel_depth * kernel_height * kernel_width;
906       const size_t output_depth = op->output_depth;
907       const size_t output_height = op->output_height;
908       const size_t output_width = op->output_width;
909 
910       const size_t step_height = op->step_height;
911       const size_t step_width = op->step_width;
912 
913       switch (kernel_size) {
914         case 9: {
915           struct q8dwconv2d_context context = {
916               .groups = groups,
917               .indirection_buffer = (const uint8_t**)op->indirection_buffer,
918               .indirection_buffer_row_stride = step_height,
919               .indirection_buffer_col_stride =
920                   kernel_height * step_width * sizeof(void*),
921               .packed_weights = op->packed_weights,
922               .output = op->output,
923               .output_height = output_height,
924               .output_width = output_width,
925               .output_row_stride = output_width * op->output_pixel_stride,
926               .output_col_increment =
927                   (op->output_pixel_stride - groups) * sizeof(uint8_t),
928               .quantization_params = op->conv_quantization_params,
929               .unipass_ukernel = op->per_channel
930                   ? pytorch_qnnp_params.q8dw9.updw_per_channel
931                   : pytorch_qnnp_params.q8dw9.updw,
932           };
933           pthreadpool_compute_2d(
934               threadpool,
935               (pthreadpool_function_2d_t)compute_dwconv2d_unipass,
936               &context,
937               batch_size,
938               output_height);
939           break;
940         }
941         case 25: {
942           struct q8dwconv2d_context context = {
943               .groups = groups,
944               .group_stride = op->group_stride,
945               .indirection_buffer = (const uint8_t**)op->indirection_buffer,
946               .indirection_buffer_row_stride = step_height,
947               .indirection_buffer_col_stride =
948                   kernel_height * step_width * sizeof(void*),
949               .packed_weights = op->packed_weights,
950               .output = op->output,
951               .output_height = output_height,
952               .output_width = output_width,
953               .output_row_stride = output_width * op->output_pixel_stride,
954               .output_col_increment =
955                   (op->output_pixel_stride - groups) * sizeof(uint8_t),
956               .quantization_params = op->conv_quantization_params,
957               .multipass_ukernel = op->per_channel
958                   ? pytorch_qnnp_params.q8dw25.mpdw_per_channel
959                   : pytorch_qnnp_params.q8dw25.mpdw,
960           };
961           pthreadpool_compute_2d(
962               threadpool,
963               (pthreadpool_function_2d_t)compute_dwconv2d_multiipass,
964               &context,
965               batch_size,
966               output_height);
967           break;
968         }
969         case 27: {
970           struct q8dwconv3d_context context = {
971               .groups = groups,
972               .group_stride = op->group_stride,
973               .indirection_buffer = (const uint8_t**)op->indirection_buffer,
974               .indirection_buffer_slice_stride = step_height * output_height,
975               .indirection_buffer_row_stride = step_height * sizeof(void*),
976               .indirection_buffer_col_stride =
977                   kernel_height * kernel_depth * step_width * sizeof(void*),
978               .packed_weights = op->packed_weights,
979               .output = op->output,
980               .output_depth = output_depth,
981               .output_height = output_height,
982               .output_width = output_width,
983               .output_slice_stride =
984                   output_height * output_width * op->output_pixel_stride,
985               .quantization_params = op->conv_quantization_params,
986               .multipass_ukernel = pytorch_qnnp_params.q8dw27.mpdw,
987           };
988           pthreadpool_compute_2d(
989               threadpool,
990               (pthreadpool_function_2d_t)compute_dwconv3d_multiipass,
991               &context,
992               batch_size,
993               output_depth);
994           break;
995         }
996         default:
997           PYTORCH_QNNP_UNREACHABLE;
998       }
999       break;
1000     }
1001     case pytorch_qnnp_ukernel_type_xzp_gemm: {
1002       const size_t batch_size = op->batch_size;
1003       const size_t groups = op->groups;
1004       const size_t group_input_channels = op->group_input_channels;
1005       const size_t group_output_channels = op->group_output_channels;
1006       const uint32_t mr = pytorch_qnnp_params.q8conv_xzp.mr;
1007       const uint32_t nr = pytorch_qnnp_params.q8conv_xzp.nr;
1008       const uint32_t kr = pytorch_qnnp_params.q8conv_xzp.kr;
1009       const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
1010       const size_t n_stride = (group_output_channels + (nr - 1)) & -nr;
1011 
1012       /* compute input row sum */
1013       const size_t input_size = op->input_height * op->input_width;
1014       int32_t* a_sum = (int32_t*)op->a_sum;
1015 
1016       struct q8sum_rows_context context = {
1017           .a = op->input,
1018           .groups = groups,
1019           .m = input_size,
1020           .k = group_input_channels,
1021           .a_stride = op->input_pixel_stride,
1022           .multiplier = (int32_t)-op->kernel_zero_point,
1023           .a_sum = a_sum,
1024           .a_sum_stride = input_size,
1025           .ukernel = pytorch_qnnp_params.q8sum_rows.sum_rows,
1026       };
1027       pthreadpool_compute_3d_tiled(
1028           threadpool,
1029           (pthreadpool_function_3d_tiled_t)compute_sum_rows,
1030           &context,
1031           groups,
1032           batch_size,
1033           input_size,
1034           1,
1035           1,
1036           pytorch_qnnp_params.q8sum_rows.m);
1037 
1038       struct q8gemm_xzp_context q8gemm_xzp_context = {
1039           .k = group_input_channels,
1040           .k_stride = k_stride,
1041           .n = group_output_channels,
1042           .n_stride = n_stride,
1043           .a = op->input,
1044           .a_stride = op->input_pixel_stride,
1045           .packed_w = op->packed_weights,
1046           .c = op->output,
1047           .c_stride = op->output_pixel_stride,
1048           .a_sum = a_sum,
1049           .groups = op->groups,
1050           .batch_size = batch_size,
1051           .a_sum_stride = input_size,
1052           .requantization_params = op->requantization_params,
1053           .ukernel = pytorch_qnnp_params.q8conv_xzp.gemm,
1054       };
1055       pthreadpool_compute_4d_tiled(
1056           threadpool,
1057           (pthreadpool_function_4d_tiled_t)compute_q8gemm_xzp,
1058           &q8gemm_xzp_context,
1059           groups,
1060           batch_size * input_size,
1061           input_size,
1062           group_output_channels,
1063           1,
1064           input_size,
1065           mr,
1066           nr);
1067       break;
1068     }
1069     case pytorch_qnnp_ukernel_type_gemm: {
1070       const size_t batch_size = op->batch_size;
1071       const size_t groups = op->groups;
1072       const size_t group_input_channels = op->group_input_channels;
1073       const size_t group_output_channels = op->group_output_channels;
1074       const uint32_t mr = pytorch_qnnp_params.q8conv.mr;
1075       const uint32_t nr = pytorch_qnnp_params.q8conv.nr;
1076       const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
1077       const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
1078       const size_t n_stride = (group_output_channels + (nr - 1)) & -nr;
1079       const size_t output_depth = op->output_depth;
1080       const size_t output_size = (output_depth != 0 ? output_depth : 1) *
1081           op->output_height * op->output_width;
1082 
1083       struct q8gemm_context q8gemm_context = {
1084           .k = group_input_channels,
1085           .k_stride = k_stride,
1086           .n = group_output_channels,
1087           .n_stride = n_stride,
1088           .a = op->input,
1089           .a_stride = op->input_pixel_stride,
1090           .packed_w = op->packed_weights,
1091           .c = op->output,
1092           .c_stride = op->output_pixel_stride,
1093           .quantization_params = op->conv_quantization_params,
1094           .ukernel = pytorch_qnnp_params.q8conv.gemm,
1095       };
1096 
1097       pthreadpool_compute_4d_tiled(
1098           threadpool,
1099           (pthreadpool_function_4d_tiled_t)compute_q8gemm,
1100           &q8gemm_context,
1101           groups,
1102           batch_size * output_size,
1103           output_size,
1104           group_output_channels,
1105           1,
1106           output_size,
1107           mr,
1108           nr);
1109       break;
1110     }
1111 #ifdef NO_PREPACK_SPARSE_KERNEL
1112     case pytorch_qnnp_ukernel_type_gemm_sparse_dq: {
1113       const size_t batch_size = op->batch_size;
1114       const size_t groups = op->groups;
1115       const size_t group_output_channels = op->group_output_channels;
1116       const uint32_t mr = pytorch_qnnp_params.q8gemm_sparse_c1x4.mr;
1117       const uint32_t nr = pytorch_qnnp_params.q8gemm_sparse_c1x4.nr;
1118 
1119       const size_t output_size = op->output_height * op->output_width;
1120       struct q8gemm_sparse_dq_context q8gemm_sparse_dq_context = {
1121           .a = op->input,
1122           .a_stride = op->input_pixel_stride,
1123           .kernel_col_indices = op->sparse_matrix.col_indices,
1124           .kernel_row_values = op->sparse_matrix.row_values,
1125           .kernel_values = op->sparse_matrix.values,
1126           .bias = (const float*)op->bias,
1127           .c = (float*)op->output,
1128           .c_stride = op->output_pixel_stride,
1129           .quantization_params = op->dynamic_conv_quantization_params,
1130           .ukernel = pytorch_qnnp_params.q8gemm_sparse_c1x4.gemm_dq,
1131       };
1132 
1133       pthreadpool_compute_4d_tiled(
1134           threadpool,
1135           (pthreadpool_function_4d_tiled_t)compute_q8gemm_sparse_dq,
1136           &q8gemm_sparse_dq_context,
1137           groups,
1138           batch_size * output_size,
1139           output_size,
1140           group_output_channels,
1141           1,
1142           output_size,
1143           mr,
1144           nr);
1145       break;
1146     }
1147 #endif
1148     case pytorch_qnnp_ukernel_type_gemm_prepackA_sparse_dq: {
1149       const size_t batch_size = op->batch_size;
1150       const size_t groups = op->groups;
1151       const size_t group_input_channels = op->group_input_channels;
1152       const size_t group_output_channels = op->group_output_channels;
1153       uint32_t mr, log2_mr, nr, kr, log2_row_block_size;
1154       pytorch_q8gemm_sparse_packA_ukernel_function prepack_kernel;
1155       struct pytorch_q8gemm_sparse_parameters* pytorch_q8gemm_sparse_params =
1156           NULL; // used to assign ukernel
1157       if (op->sparse_matrix.row_block_size == 1 &&
1158           op->sparse_matrix.col_block_size == 4) {
1159         mr = pytorch_qnnp_params.q8gemm_sparse_c1x4.mr;
1160         log2_mr = pytorch_qnnp_params.q8gemm_sparse_c1x4.log2_mr;
1161         log2_row_block_size = 0;
1162         nr = pytorch_qnnp_params.q8gemm_sparse_c1x4.nr;
1163         kr = pytorch_qnnp_params.q8gemm_sparse_c1x4.kr;
1164         prepack_kernel = pytorch_qnnp_params.q8gemm_sparse_c1x4.packA;
1165         pytorch_q8gemm_sparse_params = &pytorch_qnnp_params.q8gemm_sparse_c1x4;
1166       } else if (op->sparse_matrix.row_block_size == 8 &&
1167           op->sparse_matrix.col_block_size == 1) {
1168         mr = pytorch_qnnp_params.q8gemm_sparse_c8x1.mr;
1169         log2_mr = pytorch_qnnp_params.q8gemm_sparse_c8x1.log2_mr;
1170         log2_row_block_size = 3;
1171         nr = pytorch_qnnp_params.q8gemm_sparse_c8x1.nr;
1172         kr = pytorch_qnnp_params.q8gemm_sparse_c8x1.kr;
1173         prepack_kernel = pytorch_qnnp_params.q8gemm_sparse_c8x1.packA;
1174         pytorch_q8gemm_sparse_params = &pytorch_qnnp_params.q8gemm_sparse_c8x1;
1175       } else {
1176         return pytorch_qnnp_status_invalid_parameter;
1177       }
1178 
1179       const size_t output_size = op->output_height * op->output_width;
1180       const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
1181       const size_t m_stride = (output_size + (mr - 1)) & -mr;
1182       op->prepacked_a =
1183         (uint8_t*)realloc((void*)op->prepacked_a, k_stride * m_stride);
1184       if (op->prepacked_a == NULL) {
1185         pytorch_qnnp_log_error(
1186             "failed to allocate %zu bytes for packed activation buffer",
1187             (k_stride * m_stride));
1188         return pytorch_qnnp_status_out_of_memory;
1189       }
1190 
1191       struct q8gemm_prepackA_sparse_dq_context
1192           q8gemm_prepack_sparse_dq_context = {
1193               .k = group_input_channels,
1194               .a = op->input,
1195               .a_stride = op->input_pixel_stride,
1196               .a_packed = op->prepacked_a,
1197               .a_packed_stride = k_stride * mr,
1198               .log2_mr = log2_mr,
1199               .log2_row_block_size = log2_row_block_size,
1200               .kernel_indices_dtype = op->sparse_matrix.indices_dtype,
1201               .kernel_values = op->sparse_matrix.values,
1202               .bias = (const float*)op->bias,
1203               .c = (float*)op->output,
1204               .c_stride = op->output_pixel_stride,
1205               .quantization_params = op->dynamic_conv_quantization_params,
1206               .prepack_ukernel = prepack_kernel,
1207               // kernel_col_indices, kernel_row_values, and ukernel assigned
1208               // below
1209           };
1210 
1211       switch (q8gemm_prepack_sparse_dq_context.kernel_indices_dtype) {
1212         case pytorch_qnnp_sparse_matrix_indices_dtype_uint32_t:
1213           q8gemm_prepack_sparse_dq_context.kernel_col_indices_w32 =
1214               op->sparse_matrix.col_indices_w32;
1215           q8gemm_prepack_sparse_dq_context.kernel_row_values_w32 =
1216               op->sparse_matrix.row_values_w32;
1217           q8gemm_prepack_sparse_dq_context.ukernel_w32 =
1218               pytorch_q8gemm_sparse_params->packedA_w32_gemm_dq;
1219           break;
1220         case pytorch_qnnp_sparse_matrix_indices_dtype_uint16_t:
1221           q8gemm_prepack_sparse_dq_context.kernel_col_indices_w16 =
1222               op->sparse_matrix.col_indices_w16;
1223           q8gemm_prepack_sparse_dq_context.kernel_row_values_w16 =
1224               op->sparse_matrix.row_values_w16;
1225           q8gemm_prepack_sparse_dq_context.ukernel_w16 =
1226               pytorch_q8gemm_sparse_params->packedA_w16_gemm_dq;
1227           break;
1228         case pytorch_qnnp_sparse_matrix_indices_dtype_uint8_t:
1229           q8gemm_prepack_sparse_dq_context.kernel_col_indices_w8 =
1230               op->sparse_matrix.col_indices_w8;
1231           q8gemm_prepack_sparse_dq_context.kernel_row_values_w8 =
1232               op->sparse_matrix.row_values_w8;
1233           q8gemm_prepack_sparse_dq_context.ukernel_w8 =
1234               pytorch_q8gemm_sparse_params->packedA_w8_gemm_dq;
1235           break;
1236         case pytorch_qnnp_sparse_matrix_indices_dtype_invalid:
1237           // Catch invalid index type and return early.
1238           // This ensures all subsequent calls will have a valid index type.
1239           pytorch_qnnp_log_error(
1240               "Invalid indices dtype specified for "
1241               "operator-run pytorch_qnnp_ukernel_type_gemm_prepackA_sparse_dq");
1242           return pytorch_qnnp_status_invalid_parameter;
1243       }
1244 
1245       // This batch size is not the actual batch size of the op
1246       // The batch size is modified in fully-connected-sparse.c
1247       if (groups != 1 || batch_size != 1) {
1248         pytorch_qnnp_log_error("pytorch_qnnp_ukernel_type_gemm_prepackA_sparse_dq "
1249             "works with group size = 1, batch_size = 1.\n");
1250         return pytorch_qnnp_status_invalid_parameter;
1251       }
1252 
1253       pthreadpool_compute_4d_tiled(
1254           threadpool,
1255           (pthreadpool_function_4d_tiled_t)compute_q8gemm_prepack_a_sparse,
1256           &q8gemm_prepack_sparse_dq_context,
1257           1,
1258           1,
1259           output_size,
1260           1,
1261           1,
1262           1,
1263           mr,
1264           1);
1265 
1266       pthreadpool_compute_4d_tiled(
1267           threadpool,
1268           (pthreadpool_function_4d_tiled_t)compute_q8gemm_prepacked_sparse_dq,
1269           &q8gemm_prepack_sparse_dq_context,
1270           groups,
1271           batch_size * output_size,
1272           output_size,
1273           group_output_channels,
1274           1,
1275           output_size,
1276           mr,
1277           nr);
1278       break;
1279     }
1280     case pytorch_qnnp_ukernel_type_conv: {
1281       const size_t batch_size = op->batch_size;
1282       const size_t groups = op->groups;
1283       const size_t group_input_channels = op->group_input_channels;
1284       const size_t group_output_channels = op->group_output_channels;
1285       const uint32_t mr = pytorch_qnnp_params.q8conv.mr;
1286       const uint32_t nr = pytorch_qnnp_params.q8conv.nr;
1287       const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
1288       const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
1289       const size_t n_stride = (group_output_channels + (nr - 1)) & -nr;
1290       const size_t output_depth = op->output_depth;
1291       const size_t output_size = (output_depth != 0 ? output_depth : 1) *
1292           op->output_height * op->output_width;
1293       const size_t kernel_depth = op->kernel_depth;
1294       const size_t kernel_size = (kernel_depth != 0 ? kernel_depth : 1) *
1295           op->kernel_height * op->kernel_width;
1296       const size_t m_stride = round_up(output_size, mr);
1297 
1298       struct q8conv_context q8conv_context = {
1299           .bs = batch_size,
1300           .ks = kernel_size,
1301           .kc = group_input_channels,
1302           .kc_stride = k_stride * kernel_size,
1303           .m = output_size,
1304           .m_stride = m_stride,
1305           .n = group_output_channels,
1306           .n_stride = n_stride,
1307           .indirect_a = (const uint8_t**)op->indirection_buffer,
1308           .packed_w = op->packed_weights,
1309           .c = op->output,
1310           .c_stride = op->output_pixel_stride,
1311           .quantization_params = op->conv_quantization_params,
1312           .ukernel = pytorch_qnnp_params.q8conv.conv,
1313       };
1314 
1315       pthreadpool_compute_4d_tiled(
1316           threadpool,
1317           (pthreadpool_function_4d_tiled_t)compute_q8conv,
1318           &q8conv_context,
1319           groups,
1320           batch_size,
1321           output_size,
1322           group_output_channels,
1323           1,
1324           1,
1325           mr,
1326           nr);
1327       break;
1328     }
1329     case pytorch_qnnp_ukernel_type_average_pooling: {
1330       const uint32_t kr = pytorch_qnnp_params.q8avgpool.kr;
1331       const uint32_t mr = pytorch_qnnp_params.q8avgpool.mr;
1332       const uint32_t qr = pytorch_qnnp_params.q8avgpool.qr;
1333       const size_t channels = op->channels;
1334       const size_t output_width = op->output_width;
1335       const size_t output_height = op->output_height;
1336       const size_t pooling_height = op->kernel_height;
1337       const size_t pooling_width = op->kernel_width;
1338       const size_t pooling_size = pooling_height * pooling_width;
1339 
1340       const size_t indirect_input_height_stride =
1341           op->step_height * sizeof(void*);
1342       const size_t output_height_stride =
1343           output_width * op->output_pixel_stride;
1344 
1345       size_t multipass_adjustment = 0;
1346       if (channels >= kr && pooling_size > mr) {
1347         multipass_adjustment = round_up(pooling_size - mr, qr) + mr - qr;
1348       }
1349       struct average_pooling_context context = {
1350           .indirect_input = op->indirection_buffer,
1351           .indirect_input_batch_stride =
1352               output_height * indirect_input_height_stride,
1353           .indirect_input_height_stride = indirect_input_height_stride,
1354           .output = op->output,
1355           .output_batch_stride = output_height * output_height_stride,
1356           .output_height_stride = output_height_stride,
1357           .output_width = output_width,
1358           .pooling_size = pooling_size,
1359           .channels = channels,
1360           .packed_channels = (channels + (kr - 1)) & -kr,
1361           .zero = op->zero_pointer,
1362           .input_increment =
1363               (pooling_height * op->step_width - multipass_adjustment) *
1364               sizeof(void*),
1365           .output_increment =
1366               (op->output_pixel_stride - channels) * sizeof(uint8_t),
1367           .quantization_params = op->avgpool_quantization_params,
1368       };
1369 
1370       pthreadpool_function_2d_t compute_function = NULL;
1371       if (channels < kr) {
1372         compute_function =
1373             (pthreadpool_function_2d_t)compute_average_pooling_unipass;
1374         context.unipass_ukernel = pytorch_qnnp_params.q8avgpool.ltkr;
1375       } else {
1376         if (pooling_size <= mr) {
1377           compute_function =
1378               (pthreadpool_function_2d_t)compute_average_pooling_unipass;
1379           context.unipass_ukernel = pytorch_qnnp_params.q8avgpool.gekr_lemr;
1380         } else {
1381           compute_function =
1382               (pthreadpool_function_2d_t)compute_average_pooling_multipass;
1383           context.multipass_ukernel = pytorch_qnnp_params.q8avgpool.gekr_gtmr;
1384         }
1385       }
1386 
1387       pthreadpool_compute_2d(
1388           threadpool,
1389           compute_function,
1390           &context,
1391           op->batch_size,
1392           output_height);
1393       break;
1394     }
1395     case pytorch_qnnp_ukernel_type_max_pooling: {
1396       const uint32_t kr = pytorch_qnnp_params.u8maxpool.kr;
1397       const uint32_t mr = pytorch_qnnp_params.u8maxpool.mr;
1398       const uint32_t qr = pytorch_qnnp_params.u8maxpool.qr;
1399       const size_t channels = op->channels;
1400       const size_t output_width = op->output_width;
1401       const size_t output_height = op->output_height;
1402       const size_t pooling_height = op->kernel_height;
1403       const size_t pooling_width = op->kernel_width;
1404       const size_t pooling_size = pooling_height * pooling_width;
1405 
1406       const size_t indirect_input_height_stride =
1407           op->step_height * sizeof(void*);
1408       const size_t output_height_stride =
1409           output_width * op->output_pixel_stride;
1410 
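      /* Max pooling mirrors the average-pooling setup; doz() is the saturating
       * "difference or zero" subtraction from qnnpack/math.h. */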
1411       size_t multipass_adjustment = pooling_size;
1412       if (channels >= kr) {
1413         multipass_adjustment = round_up(doz(pooling_size, mr), qr) + mr;
1414       }
1415       struct max_pooling_context context = {
1416           .indirect_input = op->indirection_buffer,
1417           .indirect_input_batch_stride =
1418               output_height * indirect_input_height_stride,
1419           .indirect_input_height_stride = indirect_input_height_stride,
1420           .output = op->output,
1421           .output_batch_stride = output_height * output_height_stride,
1422           .output_height_stride = output_height_stride,
1423           .output_width = output_width,
1424           .pooling_size = pooling_size,
1425           .channels = channels,
1426           .input_increment =
1427               (pooling_height * op->step_width - multipass_adjustment) *
1428               sizeof(void*),
1429           .output_increment =
1430               (op->output_pixel_stride - channels) * sizeof(uint8_t),
1431           .params = op->u8_clamping_params,
1432           .ukernel = channels < kr ? pytorch_qnnp_params.u8maxpool.ltkr
1433                                    : pytorch_qnnp_params.u8maxpool.gekr,
1434       };
1435 
1436       pthreadpool_compute_2d(
1437           threadpool,
1438           (pthreadpool_function_2d_t)compute_max_pooling,
1439           &context,
1440           op->batch_size,
1441           output_height);
1442       break;
1443     }
1444     case pytorch_qnnp_ukernel_type_add: {
1445       const size_t batch_size = op->batch_size;
1446       const size_t channels = op->channels;
1447       const size_t a_stride = op->input_pixel_stride;
1448       const size_t b_stride = op->input2_pixel_stride;
1449       const size_t y_stride = op->output_pixel_stride;
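      /*
       * The XOR/OR expression below is zero only when a_stride, b_stride and
       * y_stride all equal the channel count, i.e. the tensors are densely
       * packed; in that case (or when there is just one image, so row strides
       * never come into play) the whole batch is added as one contiguous byte
       * range, tiled into block_size chunks across threads.
       */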
1450       if ((((a_stride ^ channels) | (b_stride ^ channels) |
1451             (y_stride ^ channels)) == 0) ||
1452           batch_size == 1) {
1453         const size_t block_size = 4096;
1454         struct q8add_contiguous_context add_context = {
1455             .a = op->input,
1456             .b = op->input2,
1457             .y = op->output,
1458             .quantization_params = op->add_quantization_params,
1459             .ukernel = pytorch_qnnp_params.q8vadd,
1460         };
1461         pthreadpool_compute_1d_tiled(
1462             threadpool,
1463             (pthreadpool_function_1d_tiled_t)compute_q8add_contiguous,
1464             &add_context,
1465             batch_size * channels * sizeof(uint8_t),
1466             block_size);
1467       } else {
1468         struct q8add_strided_context add_context = {
1469             .a = op->input,
1470             .a_stride = a_stride * sizeof(uint8_t),
1471             .b = op->input2,
1472             .b_stride = b_stride * sizeof(uint8_t),
1473             .y = op->output,
1474             .y_stride = y_stride * sizeof(uint8_t),
1475             .n = channels,
1476             .quantization_params = op->add_quantization_params,
1477             .ukernel = pytorch_qnnp_params.q8vadd,
1478         };
1479         pthreadpool_compute_1d_tiled(
1480             threadpool,
1481             (pthreadpool_function_1d_tiled_t)compute_q8add_strided,
1482             &add_context,
1483             batch_size,
1484             1);
1485       }
1486       break;
1487     }
1488     case pytorch_qnnp_ukernel_type_global_average_pooling: {
1489       const uint32_t nr = pytorch_qnnp_params.q8gavgpool.nr;
1490       const uint32_t mr = pytorch_qnnp_params.q8gavgpool.mr;
1491       const size_t input_pixel_stride =
1492           op->input_pixel_stride * sizeof(uint8_t);
1493       const size_t input_width = op->input_width;
1494       const size_t channels = op->channels;
1495       struct global_average_pooling_context context = {
1496           .input = op->input,
1497           .zero = op->zero_pointer,
1498           .input_pixel_stride = input_pixel_stride,
1499           .input_batch_stride = input_pixel_stride * input_width,
1500           .input_elements = input_width,
1501           .channels = channels,
1502           .packed_channels = (channels + (nr - 1)) & -nr,
1503           .output = op->output,
1504           .output_batch_stride = op->output_pixel_stride * sizeof(uint8_t),
1505           .quantization_params = op->avgpool_quantization_params,
1506       };
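      /* ltnr serves channel counts below nr; otherwise a single pass is
       * enough when the pooled extent (input_width) is at most mr, and the
       * multipass kernel handles larger inputs. */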
1507       pthreadpool_function_1d_t compute_function = NULL;
1508       if (channels < nr) {
1509         compute_function =
1510             (pthreadpool_function_1d_t)compute_global_average_pooling_unipass;
1511         context.unipass_ukernel = pytorch_qnnp_params.q8gavgpool.ltnr;
1512       } else {
1513         if (input_width <= mr) {
1514           compute_function =
1515               (pthreadpool_function_1d_t)compute_global_average_pooling_unipass;
1516           context.unipass_ukernel = pytorch_qnnp_params.q8gavgpool.genr_lemr;
1517         } else {
1518           compute_function = (pthreadpool_function_1d_t)
1519               compute_global_average_pooling_multipass;
1520           context.multipass_ukernel = pytorch_qnnp_params.q8gavgpool.genr_gtmr;
1521         }
1522       }
1523 
1524       pthreadpool_compute_1d(
1525           threadpool, compute_function, &context, op->batch_size);
1526       break;
1527     }
1528     case pytorch_qnnp_ukernel_type_lut: {
1529       const size_t batch_size = op->batch_size;
1530       const size_t channels = op->channels;
1531       const size_t x_stride = op->input_pixel_stride;
1532       const size_t y_stride = op->output_pixel_stride;
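      /* Same contiguity test as the add operator: densely packed tensors (or
       * a single image) take the tiled contiguous path; otherwise the lookup
       * table is applied row by row through the strided path. */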
1533       if ((((x_stride ^ channels) | (y_stride ^ channels)) == 0) ||
1534           batch_size == 1) {
1535         const size_t block_size = 1024;
1536         struct lut_contiguous_context context = {
1537             .x = op->input,
1538             .x_stride = x_stride * sizeof(uint8_t),
1539             .t = op->lookup_table,
1540             .y = op->output,
1541             .y_stride = y_stride * sizeof(uint8_t),
1542             .ukernel = pytorch_qnnp_params.x8lut,
1543         };
1544         pthreadpool_compute_1d_tiled(
1545             threadpool,
1546             (pthreadpool_function_1d_tiled_t)compute_lut_contiguous,
1547             &context,
1548             batch_size * channels * sizeof(uint8_t),
1549             block_size);
1550       } else {
1551         struct lut_strided_context context = {
1552             .n = channels,
1553             .x = op->input,
1554             .x_stride = x_stride * sizeof(uint8_t),
1555             .t = op->lookup_table,
1556             .y = op->output,
1557             .y_stride = y_stride * sizeof(uint8_t),
1558             .ukernel = pytorch_qnnp_params.x8lut,
1559         };
1560         pthreadpool_compute_1d(
1561             threadpool,
1562             (pthreadpool_function_1d_t)compute_lut_strided,
1563             &context,
1564             batch_size);
1565       }
1566       break;
1567     }
1568     case pytorch_qnnp_ukernel_type_clamp: {
1569       const size_t batch_size = op->batch_size;
1570       const size_t channels = op->channels;
1571       const size_t x_stride = op->input_pixel_stride;
1572       const size_t y_stride = op->output_pixel_stride;
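      /* Clamp uses the same contiguous-vs-strided dispatch as LUT above. */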
1573       if ((((x_stride ^ channels) | (y_stride ^ channels)) == 0) ||
1574           batch_size == 1) {
1575         const size_t block_size = 4096;
1576         struct clamp_contiguous_context context = {
1577             .x = op->input,
1578             .x_stride = x_stride * sizeof(uint8_t),
1579             .y = op->output,
1580             .y_stride = y_stride * sizeof(uint8_t),
1581             .ukernel = pytorch_qnnp_params.u8clamp,
1582             .params = op->u8_clamping_params,
1583         };
1584         pthreadpool_compute_1d_tiled(
1585             threadpool,
1586             (pthreadpool_function_1d_tiled_t)compute_clamp_contiguous,
1587             &context,
1588             batch_size * channels * sizeof(uint8_t),
1589             block_size);
1590       } else {
1591         struct clamp_strided_context context = {
1592             .n = channels,
1593             .x = op->input,
1594             .x_stride = x_stride * sizeof(uint8_t),
1595             .y = op->output,
1596             .y_stride = y_stride * sizeof(uint8_t),
1597             .ukernel = pytorch_qnnp_params.u8clamp,
1598             .params = op->u8_clamping_params,
1599         };
1600         pthreadpool_compute_1d(
1601             threadpool,
1602             (pthreadpool_function_1d_t)compute_clamp_strided,
1603             &context,
1604             batch_size);
1605       }
1606       break;
1607     }
1608     case pytorch_qnnp_ukernel_type_softargmax: {
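      /*
       * Softargmax applies two micro-kernels per image: u8rmax finds the row
       * maximum, then u8lut32norm maps every element through the operator's
       * precomputed lookup table and normalizes the accumulated sum.
       */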
1609       struct u8softargmax_context context = {
1610           .n = op->channels,
1611           .x = op->input,
1612           .x_stride = op->input_pixel_stride * sizeof(uint8_t),
1613           .t = op->lookup_table,
1614           .y = op->output,
1615           .y_stride = op->output_pixel_stride * sizeof(uint8_t),
1616           .rmax_ukernel = pytorch_qnnp_params.u8rmax,
1617           .lut_norm_ukernel = pytorch_qnnp_params.u8lut32norm,
1618       };
1619       pthreadpool_compute_1d(
1620           threadpool,
1621           (pthreadpool_function_1d_t)compute_u8softargmax,
1622           &context,
1623           op->batch_size);
1624       break;
1625     }
1626     case pytorch_qnnp_ukernel_type_channel_shuffle: {
1627       const size_t groups = op->groups;
1628       struct channel_shuffle_context channel_shuffle_context = {
1629           .x = op->input,
1630           .x_stride = op->input_pixel_stride * sizeof(uint8_t),
1631           .y = op->output,
1632           .y_stride = op->output_pixel_stride * sizeof(uint8_t),
1633           .n = op->group_channels * sizeof(uint8_t),
1634           .m = groups,
1635       };
1636       pthreadpool_function_1d_t compute_function = NULL;
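      /* Groups of 2, 3 and 4 map to specialized x8zip interleaving kernels;
       * larger counts fall back to the variable-width xm kernel. Group counts
       * of 0 or 1 are rejected at operator creation, hence unreachable. */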
1637       switch (groups) {
1638         case 2:
1639           compute_function =
1640               (pthreadpool_function_1d_t)compute_channel_shuffle_fixed;
1641           channel_shuffle_context.fixed_ukernel = pytorch_qnnp_params.x8zip.x2;
1642           break;
1643         case 3:
1644           compute_function =
1645               (pthreadpool_function_1d_t)compute_channel_shuffle_fixed;
1646           channel_shuffle_context.fixed_ukernel = pytorch_qnnp_params.x8zip.x3;
1647           break;
1648         case 4:
1649           compute_function =
1650               (pthreadpool_function_1d_t)compute_channel_shuffle_fixed;
1651           channel_shuffle_context.fixed_ukernel = pytorch_qnnp_params.x8zip.x4;
1652           break;
1653         default:
1654           compute_function =
1655               (pthreadpool_function_1d_t)compute_channel_shuffle_variable;
1656           channel_shuffle_context.variable_ukernel =
1657               pytorch_qnnp_params.x8zip.xm;
1658           break;
1659         case 0:
1660         case 1:
1661           PYTORCH_QNNP_UNREACHABLE;
1662       }
1663       pthreadpool_compute_1d(
1664           threadpool,
1665           compute_function,
1666           &channel_shuffle_context,
1667           op->batch_size);
1668       break;
1669     }
1670     default:
1671       PYTORCH_QNNP_UNREACHABLE;
1672   }
1673   return pytorch_qnnp_status_success;
1674 }
1675