xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/conv-run.cc (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <qnnpack/indirection.h>
2 #include <qnnpack/log.h>
3 #include <qnnpack/operator.h>
4 #include <qnnpack/pack.h>
5 #include <qnnpack_func.h>
6 #include <cstring>
7 #include <memory>
8 #include <numeric>
9 
10 namespace qnnpack {
11 
/*
 * Read-only arguments shared by every tile of an XZP (zero-point prepacked)
 * GEMM dispatched through pthreadpool_compute_4d_tiled.
 */
struct q8gemm_xzp_context {
  size_t k;        /* input channels per group (reduction length) */
  size_t k_stride; /* k rounded up to the microkernel's kr */
  size_t n;        /* output channels per group */
  size_t n_stride; /* n rounded up to the microkernel's nr */
  const uint8_t* a;        /* input activations */
  size_t a_stride;         /* bytes between input pixels */
  const void* packed_w;    /* weights packed as k_stride bytes + int32 bias per channel */
  uint8_t* c;              /* output activations */
  size_t c_stride;         /* bytes between output pixels */
  const int32_t* a_sum;    /* per-row input sums (see compute_sum_rows) */
  size_t groups;
  size_t batch_size;
  size_t a_sum_stride;     /* elements between consecutive groups in a_sum */
  union pytorch_qnnp_q31_requantization_params requantization_params;
  const pytorch_q8gemm_xzp_ukernel_function ukernel;
};
compute_q8gemm_xzp(const struct q8gemm_xzp_context context[1],size_t group_index,size_t pixel_index,size_t mr_block_start,size_t nr_block_start,size_t group_range,size_t pixel_range,size_t mr_block_size,size_t nr_block_size)29 static void compute_q8gemm_xzp(
30     const struct q8gemm_xzp_context context[1],
31     size_t group_index,
32     size_t pixel_index,
33     size_t mr_block_start,
34     size_t nr_block_start,
35     size_t group_range /* always 1 */,
36     size_t pixel_range,
37     size_t mr_block_size,
38     size_t nr_block_size) {
39   const size_t k = context->k;
40   const size_t k_stride = context->k_stride;
41   const size_t n = context->n;
42   const size_t n_stride = context->n_stride;
43   const uint8_t* a = context->a;
44   const size_t a_stride = context->a_stride;
45   const void* packed_w = context->packed_w;
46   uint8_t* c = context->c;
47   const size_t c_stride = context->c_stride;
48   const int32_t* a_sum = context->a_sum;
49   const size_t groups = context->groups;
50   const size_t a_sum_stride = context->a_sum_stride;
51 
52   context->ukernel(
53       mr_block_size,
54       nr_block_size,
55       k,
56       a + (pixel_index + mr_block_start) * a_stride + group_index * k,
57       a_stride,
58       a_sum + pixel_index * groups + group_index * a_sum_stride +
59           mr_block_start,
60       (const void*)((uintptr_t)packed_w + (nr_block_start + group_index * n_stride) * (k_stride * sizeof(uint8_t) + sizeof(int32_t))),
61       c + (pixel_index + mr_block_start) * c_stride + nr_block_start +
62           group_index * n,
63       c_stride,
64       &context->requantization_params);
65 }
66 
/*
 * Read-only arguments shared by every tile of a plain quantized GEMM
 * (1x1 convolution / fully unrolled case) dispatched through pthreadpool.
 */
struct q8gemm_context {
  size_t k;        /* input channels per group (reduction length) */
  size_t k_stride; /* k rounded up to the microkernel's kr */
  size_t n;        /* output channels per group */
  size_t n_stride; /* n rounded up to the microkernel's nr */
  const uint8_t* a;        /* input activations */
  size_t a_stride;         /* bytes between input pixels */
  const uint8_t* packed_w; /* weights packed as k_stride bytes + int32 bias per channel */
  uint8_t* c;              /* output activations */
  size_t c_stride;         /* bytes between output pixels */
  union pytorch_qnnp_conv_quantization_params quantization_params;
  const pytorch_q8gemm_ukernel_function ukernel;
};
compute_q8gemm(const struct q8gemm_context context[1],size_t group_index,size_t pixel_index,size_t mr_block_start,size_t nr_block_start,size_t group_range,size_t pixel_range,size_t mr_block_size,size_t nr_block_size)80 static void compute_q8gemm(
81     const struct q8gemm_context context[1],
82     size_t group_index,
83     size_t pixel_index,
84     size_t mr_block_start,
85     size_t nr_block_start,
86     size_t group_range /* always 1 */,
87     size_t pixel_range,
88     size_t mr_block_size,
89     size_t nr_block_size) {
90   const size_t k = context->k;
91   const size_t k_stride = context->k_stride;
92   const size_t n = context->n;
93   const size_t n_stride = context->n_stride;
94   const uint8_t* a = context->a;
95   const size_t a_stride = context->a_stride;
96   const void* packed_w = context->packed_w;
97   uint8_t* c = context->c;
98   const size_t c_stride = context->c_stride;
99 
100   const size_t output_channel_index = nr_block_start + group_index * n;
101   context->ukernel(
102       mr_block_size,
103       nr_block_size,
104       k,
105       a + (pixel_index + mr_block_start) * a_stride + group_index * k,
106       a_stride,
107       (const void*)((uintptr_t)packed_w + (nr_block_start + group_index * n_stride) * (k_stride * sizeof(uint8_t) + sizeof(int32_t))),
108       c + (pixel_index + mr_block_start) * c_stride + nr_block_start +
109           group_index * n,
110       c_stride,
111       output_channel_index,
112       &context->quantization_params);
113 }
114 
/*
 * Read-only arguments shared by every tile of an indirection-buffer-based
 * quantized convolution dispatched through pthreadpool.
 */
struct q8conv_context {
  size_t bs;        /* batch size */
  size_t ks;        /* kernel size (taps per output pixel) */
  size_t kc;        /* input channels per group */
  size_t kc_stride; /* packed reduction stride: k_stride * kernel_size */
  size_t m;         /* output pixels per image */
  size_t m_stride;  /* m rounded up to the microkernel's mr */
  size_t n;         /* output channels per group */
  size_t n_stride;  /* n rounded up to the microkernel's nr */
  const uint8_t** indirect_a; /* indirection buffer: pointers to input rows */
  const void* packed_w;       /* packed weights + biases */
  uint8_t* c;                 /* output activations */
  size_t c_stride;            /* bytes between output pixels */
  union pytorch_qnnp_conv_quantization_params quantization_params;
  const pytorch_q8conv_ukernel_function ukernel;
};
compute_q8conv(const struct q8conv_context context[1],size_t group_index,size_t image_index,size_t mr_block_start,size_t nr_block_start,size_t group_range,size_t image_range,size_t mr_block_size,size_t nr_block_size)131 static void compute_q8conv(
132     const struct q8conv_context context[1],
133     size_t group_index,
134     size_t image_index,
135     size_t mr_block_start,
136     size_t nr_block_start,
137     size_t group_range /* always 1 */,
138     size_t image_range /* always 1 */,
139     size_t mr_block_size,
140     size_t nr_block_size) {
141   const size_t bs = context->bs;
142   const size_t ks = context->ks;
143   const size_t kc = context->kc;
144   const size_t kc_stride = context->kc_stride;
145   const size_t m = context->m;
146   const size_t m_stride = context->m_stride;
147   const size_t n = context->n;
148   const size_t n_stride = context->n_stride;
149   const uint8_t** indirect_a = context->indirect_a;
150   const void* packed_w = context->packed_w;
151   uint8_t* c = context->c;
152   const size_t c_stride = context->c_stride;
153 
154   const size_t output_channel_index = group_index * n + nr_block_start;
155   context->ukernel(
156       mr_block_size,
157       nr_block_size,
158       kc,
159       ks,
160       indirect_a +
161           (mr_block_start + (image_index + group_index * bs) * m_stride) * ks,
162       (const void*)((uintptr_t)packed_w + (nr_block_start + group_index * n_stride) * (kc_stride * sizeof(uint8_t) + sizeof(int32_t))),
163       c + (mr_block_start + image_index * m) * c_stride + group_index * n +
164           nr_block_start,
165       c_stride,
166       output_channel_index,
167       &context->quantization_params);
168 }
169 
/*
 * Read-only arguments for the input-row-sum pass used by the XZP GEMM path:
 * each output element is sum(row) * multiplier.
 */
struct q8sum_rows_context {
  const uint8_t* a;  /* input activations */
  size_t groups;
  size_t m;          /* rows (pixels) per image */
  size_t k;          /* elements summed per row (input channels per group) */
  size_t a_stride;   /* bytes between rows */
  const int32_t multiplier; /* typically -kernel_zero_point */
  int32_t* a_sum;           /* output: one int32 per row */
  size_t a_sum_stride;      /* elements between consecutive groups in a_sum */
  const pytorch_q8sum_rows_ukernel_function ukernel;
};
compute_sum_rows(const struct q8sum_rows_context context[1],size_t group_index,size_t batch_index,size_t block_start,size_t group_range,size_t batch_range,size_t block_size)181 static void compute_sum_rows(
182     const struct q8sum_rows_context context[1],
183     size_t group_index,
184     size_t batch_index,
185     size_t block_start,
186     size_t group_range /* always 1 */,
187     size_t batch_range /* always 1 */,
188     size_t block_size) {
189   const uint8_t* a = context->a;
190   const size_t groups = context->groups;
191   const size_t m = context->m;
192   const size_t k = context->k;
193   const size_t a_stride = context->a_stride;
194   const int32_t multiplier = context->multiplier;
195   int32_t* a_sum = context->a_sum;
196   const size_t a_sum_stride = context->a_sum_stride;
197 
198   context->ukernel(
199       a + batch_index * m * a_stride + group_index * k + block_start * a_stride,
200       min(block_size, m - block_start),
201       k,
202       a_stride,
203       multiplier,
204       a_sum + batch_index * groups * a_sum_stride + group_index * a_sum_stride +
205           block_start);
206 }
207 
/*
 * Read-only arguments for 2D depthwise convolution workers; one worker call
 * produces one output row of one image.
 */
struct q8dwconv2d_context {
  size_t groups;       /* channels (depthwise: groups == channels) */
  size_t group_stride; /* groups rounded up to the microkernel's cr */
  const uint8_t** indirection_buffer; /* pointers to input rows per tap */
  size_t indirection_buffer_row_stride; /* entries per output row */
  size_t indirection_buffer_col_stride; /* bytes between output columns */
  const void* packed_weights;
  uint8_t* output;
  size_t output_height;
  size_t output_width;
  size_t output_row_stride;     /* bytes between output rows */
  size_t output_col_increment;  /* bytes to skip between output pixels */
  union pytorch_qnnp_conv_quantization_params quantization_params;
  const pytorch_q8dwconv2d_up_ukernel_function unipass_ukernel;   /* 3x3 path */
  const pytorch_q8dwconv2d_mp_ukernel_function multipass_ukernel; /* 5x5 path */
};
224 
/*
 * Read-only arguments for 3D depthwise convolution workers; one worker call
 * produces one output depth-slice of one image.
 */
struct q8dwconv3d_context {
  size_t groups;       /* channels (depthwise: groups == channels) */
  size_t group_stride; /* groups rounded up to the microkernel's cr */
  const uint8_t** indirection_buffer;
  size_t indirection_buffer_slice_stride; /* entries per output depth-slice */
  size_t indirection_buffer_row_stride;
  size_t indirection_buffer_col_stride;
  const void* packed_weights;
  uint8_t* output;
  size_t output_depth;
  size_t output_height;
  size_t output_width;
  size_t output_slice_stride; /* bytes between output depth-slices */
  union pytorch_qnnp_conv_quantization_params quantization_params;
  const pytorch_q8dwconv3d_mp_ukernel_function multipass_ukernel;
};
241 
compute_dwconv2d_unipass(const struct q8dwconv2d_context context[1],size_t image,size_t output_y)242 static void compute_dwconv2d_unipass(
243     const struct q8dwconv2d_context context[1],
244     size_t image,
245     size_t output_y) {
246   const size_t output_height = context->output_height;
247 
248   context->unipass_ukernel(
249       context->groups,
250       context->output_width,
251       context->indirection_buffer +
252           (image * output_height + output_y) *
253               context->indirection_buffer_row_stride,
254       context->packed_weights,
255       context->output +
256           (image * output_height + output_y) * context->output_row_stride,
257       context->indirection_buffer_col_stride,
258       context->output_col_increment,
259       &context->quantization_params);
260 }
/*
 * pthreadpool worker: computes one output row of a 25-tap (5x5) depthwise
 * convolution with the multi-pass microkernel, which needs a per-call
 * int32 accumulator buffer of group_stride elements.
 *
 * NOTE(review): the function name has a double 'i' ("multiipass"); it is
 * referenced by that spelling at the dispatch site, so it is left as-is.
 * NOTE(review): on MSVC the buffer comes from _malloca (heap fallback for
 * large sizes) and the PYTORCH_QNNP_ALIGN(16) attribute presumably applies
 * to the pointer declaration rather than the allocation — confirm the
 * microkernel tolerates this; elsewhere it is a 16-byte-aligned VLA
 * (GCC/Clang extension in C++).
 */
static void compute_dwconv2d_multiipass(
    const struct q8dwconv2d_context context[1],
    size_t image,
    size_t output_y) {
  const size_t output_height = context->output_height;
  PYTORCH_QNNP_ALIGN(16)
#ifdef _MSC_VER
  int32_t* multipass_acc = (int32_t*)_malloca(sizeof(int32_t) * context->group_stride);
#else
  int32_t multipass_acc[context->group_stride];
#endif

  context->multipass_ukernel(
      context->groups,
      context->output_width,
      context->indirection_buffer +
          (image * output_height + output_y) *
              context->indirection_buffer_row_stride,
      context->packed_weights,
      multipass_acc,
      context->output +
          (image * output_height + output_y) * context->output_row_stride,
      context->indirection_buffer_col_stride,
      context->output_col_increment,
      &context->quantization_params);

#ifdef _MSC_VER
  /* _malloca allocations must be released with _freea. */
  _freea(multipass_acc);
#endif
}
291 
/*
 * pthreadpool worker: computes one output depth-slice of a 27-tap (3x3x3)
 * depthwise convolution with the multi-pass microkernel, using a per-call
 * int32 accumulator buffer of group_stride elements (see the 2D multipass
 * worker for the MSVC _malloca / VLA split and alignment caveat).
 *
 * The trailing 0 argument is the output column increment: the 3D path
 * writes densely packed channels.
 */
static void compute_dwconv3d_multiipass(
    const struct q8dwconv3d_context context[1],
    size_t image,
    size_t output_z) {
  const size_t output_depth = context->output_depth;
  PYTORCH_QNNP_ALIGN(16)
#ifdef _MSC_VER
  int32_t* multipass_acc =
      (int32_t*)_malloca(sizeof(int32_t) * context->group_stride);
#else
  int32_t multipass_acc[context->group_stride];
#endif

  context->multipass_ukernel(
      context->groups,
      context->output_height,
      context->output_width,
      context->indirection_buffer +
          (image * output_depth + output_z) *
              context->indirection_buffer_slice_stride,
      context->packed_weights,
      multipass_acc,
      context->output +
          (image * output_depth + output_z) * context->output_slice_stride,
      context->indirection_buffer_row_stride,
      context->indirection_buffer_col_stride,
      0,
      &context->quantization_params);

#ifdef _MSC_VER
  /* _malloca allocations must be released with _freea. */
  _freea(multipass_acc);
#endif
}
325 
/* Deleter functor so a pytorch_qnnp_operator_t can be held in a
 * std::unique_ptr and destroyed via pytorch_qnnp_delete_operator. */
struct QnnpackDeleter {
  void operator()(pytorch_qnnp_operator_t op) {
    pytorch_qnnp_delete_operator(op);
  }
};
331 
/*
 * Runs a prepared quantized convolution operator over a batch of NDHWC
 * input, parallelized on `threadpool`.
 *
 * Dispatches on the operator's ukernel_type:
 *  - dwconv: depthwise 3x3 (unipass), 5x5 (multipass), or 3x3x3 (3D multipass)
 *  - xzp_gemm: GEMM with precomputed input row sums (per-tensor quant only)
 *  - gemm: plain quantized GEMM (kernel_size folded into channels)
 *  - conv: indirection-buffer convolution
 *
 * Returns pytorch_qnnp_status_success on success; forwards the setup status
 * if rebuilding the indirection buffer fails, and returns
 * pytorch_qnnp_status_out_of_memory if the XZP row-sum buffer cannot be
 * (re)allocated. batch_size == 0 is a successful no-op.
 */
enum pytorch_qnnp_status qnnpackConv(
    const pytorch_qnnp_operator_t convolution,
    void* packed_weights,
    const size_t batch_size,
    const size_t input_depth,
    const size_t input_height,
    const size_t input_width,
    const uint8_t input_zero_point,
    const uint8_t* input,
    const uint8_t* kernel_zero_points,
    const float* requantization_scales,
    const uint8_t output_zero_point,
    const uint8_t output_min,
    const uint8_t output_max,
    uint8_t* output,
    pthreadpool_t threadpool) {
  const size_t groups = convolution->groups;
  const size_t input_pixel_stride = convolution->group_input_channels * groups;
  const size_t output_pixel_stride =
      convolution->group_output_channels * groups;
  const size_t kernel_width = convolution->kernel_width;
  const size_t kernel_height = convolution->kernel_height;
  const size_t kernel_depth = convolution->kernel_depth;
  const size_t kernel_size = kernel_height * kernel_width * kernel_depth;

  if (batch_size == 0) {
    // If no batches, return
    return pytorch_qnnp_status_success;
  }

  // Only one of the two parameter sets is populated, depending on the path.
  union pytorch_qnnp_q31_requantization_params requantization_params {};
  union pytorch_qnnp_conv_quantization_params conv_quantization_params {};
  if (convolution->ukernel_type == pytorch_qnnp_ukernel_type_xzp_gemm) {
    requantization_params = pytorch_qnnp_compute_requantization_params(
        // Note. XZP kernels are not changed for per channel quant.
        requantization_scales[0],
        output_zero_point,
        output_min,
        output_max);
  } else {
    conv_quantization_params = pytorch_qnnp_compute_conv_quantization_params(
        input_zero_point,
        kernel_zero_points,
        requantization_scales,
        output_zero_point,
        output_min,
        output_max);
  }

  // Convolution op caches a few things.
  // We need to check if the corresponding values on this
  // invocation is same as cached values.
  // If so we can skip setup step.
  if (convolution->input != input || convolution->batch_size != batch_size ||
      convolution->input_depth != input_depth ||
      convolution->input_height != input_height ||
      convolution->input_width != input_width ||
      convolution->input_pixel_stride != input_pixel_stride) {
    pytorch_qnnp_status status = pytorch_qnnp_setup_convolution_ndhwc_q8(
        convolution,
        batch_size,
        input_depth,
        input_height,
        input_width,
        input,
        input_pixel_stride,
        output,
        output_pixel_stride,
        threadpool);
    if (status != pytorch_qnnp_status_success) {
      pytorch_qnnp_log_error(
          "failed to run convolution op setup to setup indirection buffer.");
      return status;
    }
  }

  // Output pixels per image (setup above filled in the output dimensions).
  const size_t output_size = convolution->output_height *
      convolution->output_width * convolution->output_depth;

  switch (convolution->ukernel_type) {
    case pytorch_qnnp_ukernel_type_dwconv: {
      const uint32_t cr = pytorch_qnnp_params.q8dw9.cr;
      // Round groups up to a multiple of cr (cr is a power of two).
      const size_t group_stride = (groups + (cr - 1)) & -cr;

      const size_t step_height = convolution->step_height;
      const size_t step_width = convolution->step_width;

      // Depthwise kernels exist only for 3x3, 5x5 and 3x3x3 taps.
      switch (kernel_size) {
        case 9: {
          struct q8dwconv2d_context context = {
              .groups = groups,
              .group_stride = group_stride,
              .indirection_buffer =
                  (const uint8_t**)convolution->indirection_buffer,
              .indirection_buffer_row_stride = step_height,
              .indirection_buffer_col_stride =
                  kernel_height * step_width * sizeof(void*),
              .packed_weights = packed_weights,
              .output = output,
              .output_height = convolution->output_height,
              .output_width = convolution->output_width,
              .output_row_stride =
                  convolution->output_width * output_pixel_stride,
              .output_col_increment =
                  (output_pixel_stride - groups) * sizeof(uint8_t),
              .quantization_params = conv_quantization_params,
              .unipass_ukernel = convolution->per_channel
                  ? pytorch_qnnp_params.q8dw9.updw_per_channel
                  : pytorch_qnnp_params.q8dw9.updw,
              .multipass_ukernel = convolution->per_channel
                  ? pytorch_qnnp_params.q8dw25.mpdw_per_channel
                  : pytorch_qnnp_params.q8dw25.mpdw,
          };
          // One worker invocation per (image, output row).
          pthreadpool_compute_2d(
              threadpool,
              (pthreadpool_function_2d_t)compute_dwconv2d_unipass,
              &context,
              batch_size,
              convolution->output_height);
          break;
        }
        case 25: {
          // NOTE(review): identical initializer to the 9-tap case; the
          // unipass_ukernel is still taken from q8dw9 but only the
          // multipass kernel is invoked below.
          struct q8dwconv2d_context context = {
              .groups = groups,
              .group_stride = group_stride,
              .indirection_buffer =
                  (const uint8_t**)convolution->indirection_buffer,
              .indirection_buffer_row_stride = step_height,
              .indirection_buffer_col_stride =
                  kernel_height * step_width * sizeof(void*),
              .packed_weights = packed_weights,
              .output = output,
              .output_height = convolution->output_height,
              .output_width = convolution->output_width,
              .output_row_stride =
                  convolution->output_width * output_pixel_stride,
              .output_col_increment =
                  (output_pixel_stride - groups) * sizeof(uint8_t),
              .quantization_params = conv_quantization_params,
              .unipass_ukernel = convolution->per_channel
                  ? pytorch_qnnp_params.q8dw9.updw_per_channel
                  : pytorch_qnnp_params.q8dw9.updw,
              .multipass_ukernel = convolution->per_channel
                  ? pytorch_qnnp_params.q8dw25.mpdw_per_channel
                  : pytorch_qnnp_params.q8dw25.mpdw,
          };
          pthreadpool_compute_2d(
              threadpool,
              (pthreadpool_function_2d_t)compute_dwconv2d_multiipass,
              &context,
              batch_size,
              convolution->output_height);
          break;
        }
        case 27: {
          struct q8dwconv3d_context context = {
              .groups = groups,
              .group_stride = group_stride,
              .indirection_buffer =
                  (const uint8_t**)convolution->indirection_buffer,
              // NOTE(review): slice stride is in buffer entries (no
              // sizeof(void*)) while row/col strides are in bytes —
              // matches how the worker and ukernel consume them; confirm
              // against the q8dw27 microkernel before changing.
              .indirection_buffer_slice_stride =
                  step_height * convolution->output_height,
              .indirection_buffer_row_stride = step_height * sizeof(void*),
              .indirection_buffer_col_stride =
                  kernel_height * kernel_depth * step_width * sizeof(void*),
              .packed_weights = packed_weights,
              .output = output,
              .output_depth = convolution->output_depth,
              .output_height = convolution->output_height,
              .output_width = convolution->output_width,
              .output_slice_stride = convolution->output_height *
                  convolution->output_width * output_pixel_stride,
              .quantization_params = conv_quantization_params,
              .multipass_ukernel = pytorch_qnnp_params.q8dw27.mpdw,
          };
          // One worker invocation per (image, output depth-slice).
          pthreadpool_compute_2d(
              threadpool,
              (pthreadpool_function_2d_t)compute_dwconv3d_multiipass,
              &context,
              batch_size,
              convolution->output_depth);
          break;
        }
        default:
          PYTORCH_QNNP_UNREACHABLE;
      }
      break;
    }
    case pytorch_qnnp_ukernel_type_xzp_gemm: {
      const size_t group_input_channels = convolution->group_input_channels;
      const size_t group_output_channels = convolution->group_output_channels;
      const uint32_t mr = pytorch_qnnp_params.q8conv_xzp.mr;
      const uint32_t nr = pytorch_qnnp_params.q8conv_xzp.nr;
      const uint32_t kr = pytorch_qnnp_params.q8conv_xzp.kr;
      const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
      const size_t n_stride = (group_output_channels + (nr - 1)) & -nr;

      /* compute input row sum */
      const size_t input_size = input_depth * input_height * input_width;
      // Grow (or reuse) the cached row-sum buffer on the operator.
      int32_t* a_sum = (int32_t*)realloc(
          convolution->a_sum,
          sizeof(int32_t) * batch_size * groups * input_size);
      if (a_sum == nullptr) {
        pytorch_qnnp_log_error(
            "failed to allocate %zu bytes for row sum data",
            sizeof(int32_t) * batch_size * groups * input_size);
        return pytorch_qnnp_status_out_of_memory;
      }
      convolution->a_sum = a_sum;
      struct q8sum_rows_context context = {
          .a = input,
          .groups = groups,
          .m = input_size,
          .k = convolution->group_input_channels,
          .a_stride = input_pixel_stride,
          // XZP kernels are not supporting per channel quant.
          // We dont really use XZP kernels ATM.
          // Thus assigning the zero point of first channel.
          .multiplier = (int32_t)-kernel_zero_points[0],
          .a_sum = a_sum,
          .a_sum_stride = input_size,
          .ukernel = pytorch_qnnp_params.q8sum_rows.sum_rows,
      };
      // Pass 1: row sums, tiled over rows only.
      pthreadpool_compute_3d_tiled(
          threadpool,
          (pthreadpool_function_3d_tiled_t)compute_sum_rows,
          &context,
          groups,
          batch_size,
          input_size,
          1,
          1,
          pytorch_qnnp_params.q8sum_rows.m);

      struct q8gemm_xzp_context q8gemm_xzp_context = {
          .k = convolution->group_input_channels,
          .k_stride = k_stride,
          .n = convolution->group_output_channels,
          .n_stride = n_stride,
          .a = input,
          .a_stride = input_pixel_stride,
          .packed_w = packed_weights,
          .c = output,
          .c_stride = output_pixel_stride,
          .a_sum = a_sum,
          .groups = groups,
          .batch_size = batch_size,
          .a_sum_stride = input_size,
          .requantization_params = requantization_params,
          .ukernel = pytorch_qnnp_params.q8conv_xzp.gemm,
      };
      // Pass 2: the GEMM itself, tiled mr x nr over pixels and channels.
      pthreadpool_compute_4d_tiled(
          threadpool,
          (pthreadpool_function_4d_tiled_t)compute_q8gemm_xzp,
          &q8gemm_xzp_context,
          groups,
          batch_size * input_size,
          input_size,
          group_output_channels,
          1,
          input_size,
          mr,
          nr);
      break;
    }
    case pytorch_qnnp_ukernel_type_gemm: {
      const size_t group_input_channels = convolution->group_input_channels;
      const size_t group_output_channels = convolution->group_output_channels;
      const uint32_t mr = pytorch_qnnp_params.q8conv.mr;
      const uint32_t nr = pytorch_qnnp_params.q8conv.nr;
      const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
      const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
      const size_t n_stride = (group_output_channels + (nr - 1)) & -nr;

      struct q8gemm_context q8gemm_context = {
          .k = convolution->group_input_channels,
          .k_stride = k_stride,
          .n = convolution->group_output_channels,
          .n_stride = n_stride,
          .a = input,
          .a_stride = input_pixel_stride,
          .packed_w = (uint8_t*)packed_weights,
          .c = output,
          .c_stride = output_pixel_stride,
          .quantization_params = conv_quantization_params,
          .ukernel = pytorch_qnnp_params.q8conv.gemm,
      };

      pthreadpool_compute_4d_tiled(
          threadpool,
          (pthreadpool_function_4d_tiled_t)compute_q8gemm,
          &q8gemm_context,
          groups,
          batch_size * output_size,
          output_size,
          group_output_channels,
          1,
          output_size,
          mr,
          nr);
      break;
    }
    case pytorch_qnnp_ukernel_type_conv: {
      const size_t group_input_channels = convolution->group_input_channels;
      const size_t group_output_channels = convolution->group_output_channels;
      const uint32_t mr = pytorch_qnnp_params.q8conv.mr;
      const uint32_t nr = pytorch_qnnp_params.q8conv.nr;
      const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
      const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
      const size_t n_stride = (group_output_channels + (nr - 1)) & -nr;
      const size_t m_stride = round_up(output_size, mr);

      struct q8conv_context q8conv_context = {
          .bs = batch_size,
          .ks = kernel_size,
          .kc = group_input_channels,
          .kc_stride = k_stride * kernel_size,
          .m = output_size,
          .m_stride = m_stride,
          .n = group_output_channels,
          .n_stride = n_stride,
          .indirect_a = (const uint8_t**)convolution->indirection_buffer,
          .packed_w = packed_weights,
          .c = output,
          .c_stride = output_pixel_stride,
          .quantization_params = conv_quantization_params,
          .ukernel = pytorch_qnnp_params.q8conv.conv,
      };

      pthreadpool_compute_4d_tiled(
          threadpool,
          (pthreadpool_function_4d_tiled_t)compute_q8conv,
          &q8conv_context,
          groups,
          batch_size,
          output_size,
          group_output_channels,
          1,
          1,
          mr,
          nr);
      break;
    }
    default: {
      pytorch_qnnp_log_error("Invalid kernel type. QNNPACK convolution run failed.");
      PYTORCH_QNNP_UNREACHABLE;
    }
  }
  return pytorch_qnnp_status_success;
}
682 } // namespace qnnpack
683