#include <qnnpack/indirection.h>
#include <qnnpack/log.h>
#include <qnnpack/math.h>
#include <qnnpack/operator.h>
#include <qnnpack/pack.h>
#include <qnnpack_func.h>
#include <cstring>
#include <memory>
#include <numeric>

namespace qnnpack {

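// Arguments for one tiled pass of the XZP GEMM path, which cancels the
// kernel zero point using precomputed per-row sums of the input (a_sum).
// Filled in once by qnnpackConv and read concurrently by threadpool workers.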
struct q8gemm_xzp_context {
  size_t k;
  size_t k_stride;
  size_t n;
  size_t n_stride;
  const uint8_t* a;
  size_t a_stride;
  const void* packed_w;
  uint8_t* c;
  size_t c_stride;
  const int32_t* a_sum;
  size_t groups;
  size_t batch_size;
  size_t a_sum_stride;
  union pytorch_qnnp_q31_requantization_params requantization_params;
  const pytorch_q8gemm_xzp_ukernel_function ukernel;
};
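// Computes one mr x nr output tile of a grouped XZP GEMM. group_range is
// always 1, so each invocation handles a single group; a_sum supplies the
// precomputed input row sums for the rows covered by this tile.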
static void compute_q8gemm_xzp(
    const struct q8gemm_xzp_context context[1],
    size_t group_index,
    size_t pixel_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t group_range /* always 1 */,
    size_t pixel_range,
    size_t mr_block_size,
    size_t nr_block_size) {
  const size_t k = context->k;
  const size_t k_stride = context->k_stride;
  const size_t n = context->n;
  const size_t n_stride = context->n_stride;
  const uint8_t* a = context->a;
  const size_t a_stride = context->a_stride;
  const void* packed_w = context->packed_w;
  uint8_t* c = context->c;
  const size_t c_stride = context->c_stride;
  const int32_t* a_sum = context->a_sum;
  const size_t groups = context->groups;
  const size_t a_sum_stride = context->a_sum_stride;

  context->ukernel(
      mr_block_size,
      nr_block_size,
      k,
      a + (pixel_index + mr_block_start) * a_stride + group_index * k,
      a_stride,
      a_sum + pixel_index * groups + group_index * a_sum_stride +
          mr_block_start,
      (const void*)((uintptr_t)packed_w +
          (nr_block_start + group_index * n_stride) *
              (k_stride * sizeof(uint8_t) + sizeof(int32_t))),
      c + (pixel_index + mr_block_start) * c_stride + nr_block_start +
          group_index * n,
      c_stride,
      &context->requantization_params);
}

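// Arguments for the plain (non-XZP) tiled Q8 GEMM pass.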
struct q8gemm_context {
  size_t k;
  size_t k_stride;
  size_t n;
  size_t n_stride;
  const uint8_t* a;
  size_t a_stride;
  const uint8_t* packed_w;
  uint8_t* c;
  size_t c_stride;
  union pytorch_qnnp_conv_quantization_params quantization_params;
  const pytorch_q8gemm_ukernel_function ukernel;
};
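// Computes one mr x nr output tile of a grouped Q8 GEMM. Each packed output
// channel occupies k_stride weight bytes plus a 4-byte bias, which is why
// the weight offset scales by (k_stride * sizeof(uint8_t) + sizeof(int32_t)).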
static void compute_q8gemm(
    const struct q8gemm_context context[1],
    size_t group_index,
    size_t pixel_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t group_range /* always 1 */,
    size_t pixel_range,
    size_t mr_block_size,
    size_t nr_block_size) {
  const size_t k = context->k;
  const size_t k_stride = context->k_stride;
  const size_t n = context->n;
  const size_t n_stride = context->n_stride;
  const uint8_t* a = context->a;
  const size_t a_stride = context->a_stride;
  const void* packed_w = context->packed_w;
  uint8_t* c = context->c;
  const size_t c_stride = context->c_stride;

  const size_t output_channel_index = nr_block_start + group_index * n;
  context->ukernel(
      mr_block_size,
      nr_block_size,
      k,
      a + (pixel_index + mr_block_start) * a_stride + group_index * k,
      a_stride,
      (const void*)((uintptr_t)packed_w +
          (nr_block_start + group_index * n_stride) *
              (k_stride * sizeof(uint8_t) + sizeof(int32_t))),
      c + (pixel_index + mr_block_start) * c_stride + nr_block_start +
          group_index * n,
      c_stride,
      output_channel_index,
      &context->quantization_params);
}

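// Arguments for the tiled Q8 convolution pass (implicit GEMM over an
// indirection buffer of input-pixel pointers).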
struct q8conv_context {
  size_t bs;
  size_t ks;
  size_t kc;
  size_t kc_stride;
  size_t m;
  size_t m_stride;
  size_t n;
  size_t n_stride;
  const uint8_t** indirect_a;
  const void* packed_w;
  uint8_t* c;
  size_t c_stride;
  union pytorch_qnnp_conv_quantization_params quantization_params;
  const pytorch_q8conv_ukernel_function ukernel;
};
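// Computes one mr x nr output tile of a grouped Q8 convolution. Instead of a
// dense A matrix, the ukernel reads ks pointers per output pixel from the
// indirection buffer.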
static void compute_q8conv(
    const struct q8conv_context context[1],
    size_t group_index,
    size_t image_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t group_range /* always 1 */,
    size_t image_range /* always 1 */,
    size_t mr_block_size,
    size_t nr_block_size) {
  const size_t bs = context->bs;
  const size_t ks = context->ks;
  const size_t kc = context->kc;
  const size_t kc_stride = context->kc_stride;
  const size_t m = context->m;
  const size_t m_stride = context->m_stride;
  const size_t n = context->n;
  const size_t n_stride = context->n_stride;
  const uint8_t** indirect_a = context->indirect_a;
  const void* packed_w = context->packed_w;
  uint8_t* c = context->c;
  const size_t c_stride = context->c_stride;

  const size_t output_channel_index = group_index * n + nr_block_start;
  context->ukernel(
      mr_block_size,
      nr_block_size,
      kc,
      ks,
      indirect_a +
          (mr_block_start + (image_index + group_index * bs) * m_stride) * ks,
      (const void*)((uintptr_t)packed_w +
          (nr_block_start + group_index * n_stride) *
              (kc_stride * sizeof(uint8_t) + sizeof(int32_t))),
      c + (mr_block_start + image_index * m) * c_stride + group_index * n +
          nr_block_start,
      c_stride,
      output_channel_index,
      &context->quantization_params);
}

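// Arguments for the input row-sum pass that feeds the XZP GEMM.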
struct q8sum_rows_context {
  const uint8_t* a;
  size_t groups;
  size_t m;
  size_t k;
  size_t a_stride;
  const int32_t multiplier;
  int32_t* a_sum;
  size_t a_sum_stride;
  const pytorch_q8sum_rows_ukernel_function ukernel;
};
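// Sums each input row (k elements) and scales the result by the multiplier
// (the negated kernel zero point), writing one int32 per row into a_sum.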
static void compute_sum_rows(
    const struct q8sum_rows_context context[1],
    size_t group_index,
    size_t batch_index,
    size_t block_start,
    size_t group_range /* always 1 */,
    size_t batch_range /* always 1 */,
    size_t block_size) {
  const uint8_t* a = context->a;
  const size_t groups = context->groups;
  const size_t m = context->m;
  const size_t k = context->k;
  const size_t a_stride = context->a_stride;
  const int32_t multiplier = context->multiplier;
  int32_t* a_sum = context->a_sum;
  const size_t a_sum_stride = context->a_sum_stride;

  context->ukernel(
      a + batch_index * m * a_stride + group_index * k + block_start * a_stride,
      min(block_size, m - block_start),
      k,
      a_stride,
      multiplier,
      a_sum + batch_index * groups * a_sum_stride +
          group_index * a_sum_stride + block_start);
}

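// Arguments for 2D depthwise convolution. Strides into the indirection
// buffer are precomputed so workers only index by (image, output row).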
struct q8dwconv2d_context {
  size_t groups;
  size_t group_stride;
  const uint8_t** indirection_buffer;
  size_t indirection_buffer_row_stride;
  size_t indirection_buffer_col_stride;
  const void* packed_weights;
  uint8_t* output;
  size_t output_height;
  size_t output_width;
  size_t output_row_stride;
  size_t output_col_increment;
  union pytorch_qnnp_conv_quantization_params quantization_params;
  const pytorch_q8dwconv2d_up_ukernel_function unipass_ukernel;
  const pytorch_q8dwconv2d_mp_ukernel_function multipass_ukernel;
};

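// Arguments for 3D depthwise convolution. Workers index by
// (image, output slice) along the depth dimension.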
struct q8dwconv3d_context {
  size_t groups;
  size_t group_stride;
  const uint8_t** indirection_buffer;
  size_t indirection_buffer_slice_stride;
  size_t indirection_buffer_row_stride;
  size_t indirection_buffer_col_stride;
  const void* packed_weights;
  uint8_t* output;
  size_t output_depth;
  size_t output_height;
  size_t output_width;
  size_t output_slice_stride;
  union pytorch_qnnp_conv_quantization_params quantization_params;
  const pytorch_q8dwconv3d_mp_ukernel_function multipass_ukernel;
};

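// Computes one output row of a 2D depthwise convolution in a single ukernel
// pass (used when the kernel has 9 taps, e.g. 3x3).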
static void compute_dwconv2d_unipass(
    const struct q8dwconv2d_context context[1],
    size_t image,
    size_t output_y) {
  const size_t output_height = context->output_height;

  context->unipass_ukernel(
      context->groups,
      context->output_width,
      context->indirection_buffer +
          (image * output_height + output_y) *
              context->indirection_buffer_row_stride,
      context->packed_weights,
      context->output +
          (image * output_height + output_y) * context->output_row_stride,
      context->indirection_buffer_col_stride,
      context->output_col_increment,
      &context->quantization_params);
}
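// Computes one output row of a 2D depthwise convolution in multiple passes,
// accumulating into a stack-allocated int32 buffer (used for 25-tap kernels).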
static void compute_dwconv2d_multipass(
    const struct q8dwconv2d_context context[1],
    size_t image,
    size_t output_y) {
  const size_t output_height = context->output_height;
  PYTORCH_QNNP_ALIGN(16)
#ifdef _MSC_VER
  int32_t* multipass_acc =
      (int32_t*)_malloca(sizeof(int32_t) * context->group_stride);
#else
  int32_t multipass_acc[context->group_stride];
#endif

  context->multipass_ukernel(
      context->groups,
      context->output_width,
      context->indirection_buffer +
          (image * output_height + output_y) *
              context->indirection_buffer_row_stride,
      context->packed_weights,
      multipass_acc,
      context->output +
          (image * output_height + output_y) * context->output_row_stride,
      context->indirection_buffer_col_stride,
      context->output_col_increment,
      &context->quantization_params);

#ifdef _MSC_VER
  _freea(multipass_acc);
#endif
}

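// Computes one output slice of a 3D depthwise convolution in multiple
// passes, accumulating into a stack-allocated int32 buffer.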
static void compute_dwconv3d_multipass(
    const struct q8dwconv3d_context context[1],
    size_t image,
    size_t output_z) {
  const size_t output_depth = context->output_depth;
  PYTORCH_QNNP_ALIGN(16)
#ifdef _MSC_VER
  int32_t* multipass_acc =
      (int32_t*)_malloca(sizeof(int32_t) * context->group_stride);
#else
  int32_t multipass_acc[context->group_stride];
#endif

  context->multipass_ukernel(
      context->groups,
      context->output_height,
      context->output_width,
      context->indirection_buffer +
          (image * output_depth + output_z) *
              context->indirection_buffer_slice_stride,
      context->packed_weights,
      multipass_acc,
      context->output +
          (image * output_depth + output_z) * context->output_slice_stride,
      context->indirection_buffer_row_stride,
      context->indirection_buffer_col_stride,
      0,
      &context->quantization_params);

#ifdef _MSC_VER
  _freea(multipass_acc);
#endif
}

struct QnnpackDeleter {
  void operator()(pytorch_qnnp_operator_t op) {
    pytorch_qnnp_delete_operator(op);
  }
};

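// Runs a prepacked quantized convolution. Reuses the setup (indirection
// buffer) cached on the operator when the input geometry is unchanged, then
// dispatches to the depthwise, XZP GEMM, GEMM, or implicit-GEMM path.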
enum pytorch_qnnp_status qnnpackConv(
    const pytorch_qnnp_operator_t convolution,
    void* packed_weights,
    const size_t batch_size,
    const size_t input_depth,
    const size_t input_height,
    const size_t input_width,
    const uint8_t input_zero_point,
    const uint8_t* input,
    const uint8_t* kernel_zero_points,
    const float* requantization_scales,
    const uint8_t output_zero_point,
    const uint8_t output_min,
    const uint8_t output_max,
    uint8_t* output,
    pthreadpool_t threadpool) {
  const size_t groups = convolution->groups;
  const size_t input_pixel_stride = convolution->group_input_channels * groups;
  const size_t output_pixel_stride =
      convolution->group_output_channels * groups;
  const size_t kernel_width = convolution->kernel_width;
  const size_t kernel_height = convolution->kernel_height;
  const size_t kernel_depth = convolution->kernel_depth;
  const size_t kernel_size = kernel_height * kernel_width * kernel_depth;

  if (batch_size == 0) {
    // Nothing to compute for an empty batch.
    return pytorch_qnnp_status_success;
  }

  union pytorch_qnnp_q31_requantization_params requantization_params {};
  union pytorch_qnnp_conv_quantization_params conv_quantization_params {};
  if (convolution->ukernel_type == pytorch_qnnp_ukernel_type_xzp_gemm) {
    requantization_params = pytorch_qnnp_compute_requantization_params(
        // Note: XZP kernels are not updated for per-channel quantization,
        // so the scale of the first channel is used for all channels.
        requantization_scales[0],
        output_zero_point,
        output_min,
        output_max);
  } else {
    conv_quantization_params = pytorch_qnnp_compute_conv_quantization_params(
        input_zero_point,
        kernel_zero_points,
        requantization_scales,
        output_zero_point,
        output_min,
        output_max);
  }

  // The convolution operator caches a few input-dependent values.
  // If the values for this invocation match the cached ones, the setup
  // step (which rebuilds the indirection buffer) can be skipped.
  if (convolution->input != input || convolution->batch_size != batch_size ||
      convolution->input_depth != input_depth ||
      convolution->input_height != input_height ||
      convolution->input_width != input_width ||
      convolution->input_pixel_stride != input_pixel_stride) {
    pytorch_qnnp_status status = pytorch_qnnp_setup_convolution_ndhwc_q8(
        convolution,
        batch_size,
        input_depth,
        input_height,
        input_width,
        input,
        input_pixel_stride,
        output,
        output_pixel_stride,
        threadpool);
    if (status != pytorch_qnnp_status_success) {
      pytorch_qnnp_log_error(
          "failed to run convolution setup to initialize the indirection buffer.");
      return status;
    }
  }

  const size_t output_size = convolution->output_height *
      convolution->output_width * convolution->output_depth;

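  // Dispatch on the microkernel type selected when the operator was created.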
  switch (convolution->ukernel_type) {
    case pytorch_qnnp_ukernel_type_dwconv: {
      const uint32_t cr = pytorch_qnnp_params.q8dw9.cr;
      const size_t group_stride = (groups + (cr - 1)) & -cr;

      const size_t step_height = convolution->step_height;
      const size_t step_width = convolution->step_width;

      switch (kernel_size) {
        case 9: {
          struct q8dwconv2d_context context = {
              .groups = groups,
              .group_stride = group_stride,
              .indirection_buffer =
                  (const uint8_t**)convolution->indirection_buffer,
              .indirection_buffer_row_stride = step_height,
              .indirection_buffer_col_stride =
                  kernel_height * step_width * sizeof(void*),
              .packed_weights = packed_weights,
              .output = output,
              .output_height = convolution->output_height,
              .output_width = convolution->output_width,
              .output_row_stride =
                  convolution->output_width * output_pixel_stride,
              .output_col_increment =
                  (output_pixel_stride - groups) * sizeof(uint8_t),
              .quantization_params = conv_quantization_params,
              .unipass_ukernel = convolution->per_channel
                  ? pytorch_qnnp_params.q8dw9.updw_per_channel
                  : pytorch_qnnp_params.q8dw9.updw,
              .multipass_ukernel = convolution->per_channel
                  ? pytorch_qnnp_params.q8dw25.mpdw_per_channel
                  : pytorch_qnnp_params.q8dw25.mpdw,
          };
          pthreadpool_compute_2d(
              threadpool,
              (pthreadpool_function_2d_t)compute_dwconv2d_unipass,
              &context,
              batch_size,
              convolution->output_height);
          break;
        }
        case 25: {
          struct q8dwconv2d_context context = {
              .groups = groups,
              .group_stride = group_stride,
              .indirection_buffer =
                  (const uint8_t**)convolution->indirection_buffer,
              .indirection_buffer_row_stride = step_height,
              .indirection_buffer_col_stride =
                  kernel_height * step_width * sizeof(void*),
              .packed_weights = packed_weights,
              .output = output,
              .output_height = convolution->output_height,
              .output_width = convolution->output_width,
              .output_row_stride =
                  convolution->output_width * output_pixel_stride,
              .output_col_increment =
                  (output_pixel_stride - groups) * sizeof(uint8_t),
              .quantization_params = conv_quantization_params,
              .unipass_ukernel = convolution->per_channel
                  ? pytorch_qnnp_params.q8dw9.updw_per_channel
                  : pytorch_qnnp_params.q8dw9.updw,
              .multipass_ukernel = convolution->per_channel
                  ? pytorch_qnnp_params.q8dw25.mpdw_per_channel
                  : pytorch_qnnp_params.q8dw25.mpdw,
          };
          pthreadpool_compute_2d(
              threadpool,
              (pthreadpool_function_2d_t)compute_dwconv2d_multipass,
              &context,
              batch_size,
              convolution->output_height);
          break;
        }
        case 27: {
          struct q8dwconv3d_context context = {
              .groups = groups,
              .group_stride = group_stride,
              .indirection_buffer =
                  (const uint8_t**)convolution->indirection_buffer,
              .indirection_buffer_slice_stride =
                  step_height * convolution->output_height,
              .indirection_buffer_row_stride = step_height * sizeof(void*),
              .indirection_buffer_col_stride =
                  kernel_height * kernel_depth * step_width * sizeof(void*),
              .packed_weights = packed_weights,
              .output = output,
              .output_depth = convolution->output_depth,
              .output_height = convolution->output_height,
              .output_width = convolution->output_width,
              .output_slice_stride = convolution->output_height *
                  convolution->output_width * output_pixel_stride,
              .quantization_params = conv_quantization_params,
              .multipass_ukernel = pytorch_qnnp_params.q8dw27.mpdw,
          };
          pthreadpool_compute_2d(
              threadpool,
              (pthreadpool_function_2d_t)compute_dwconv3d_multipass,
              &context,
              batch_size,
              convolution->output_depth);
          break;
        }
        default:
          PYTORCH_QNNP_UNREACHABLE;
      }
      break;
    }
    case pytorch_qnnp_ukernel_type_xzp_gemm: {
      const size_t group_input_channels = convolution->group_input_channels;
      const size_t group_output_channels = convolution->group_output_channels;
      const uint32_t mr = pytorch_qnnp_params.q8conv_xzp.mr;
      const uint32_t nr = pytorch_qnnp_params.q8conv_xzp.nr;
      const uint32_t kr = pytorch_qnnp_params.q8conv_xzp.kr;
      const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
      const size_t n_stride = (group_output_channels + (nr - 1)) & -nr;

      /* Compute the input row sums. */
      const size_t input_size = input_depth * input_height * input_width;
      int32_t* a_sum = (int32_t*)realloc(
          convolution->a_sum,
          sizeof(int32_t) * batch_size * groups * input_size);
      if (a_sum == nullptr) {
        pytorch_qnnp_log_error(
            "failed to allocate %zu bytes for row sum data",
            sizeof(int32_t) * batch_size * groups * input_size);
        return pytorch_qnnp_status_out_of_memory;
      }
      convolution->a_sum = a_sum;
      struct q8sum_rows_context context = {
          .a = input,
          .groups = groups,
          .m = input_size,
          .k = convolution->group_input_channels,
          .a_stride = input_pixel_stride,
          // XZP kernels do not support per-channel quantization, and this
          // path is currently unused, so the zero point of the first
          // channel is used for all channels.
          .multiplier = (int32_t)-kernel_zero_points[0],
          .a_sum = a_sum,
          .a_sum_stride = input_size,
          .ukernel = pytorch_qnnp_params.q8sum_rows.sum_rows,
      };
      pthreadpool_compute_3d_tiled(
          threadpool,
          (pthreadpool_function_3d_tiled_t)compute_sum_rows,
          &context,
          groups,
          batch_size,
          input_size,
          1,
          1,
          pytorch_qnnp_params.q8sum_rows.m);

      struct q8gemm_xzp_context q8gemm_xzp_context = {
          .k = convolution->group_input_channels,
          .k_stride = k_stride,
          .n = convolution->group_output_channels,
          .n_stride = n_stride,
          .a = input,
          .a_stride = input_pixel_stride,
          .packed_w = packed_weights,
          .c = output,
          .c_stride = output_pixel_stride,
          .a_sum = a_sum,
          .groups = groups,
          .batch_size = batch_size,
          .a_sum_stride = input_size,
          .requantization_params = requantization_params,
          .ukernel = pytorch_qnnp_params.q8conv_xzp.gemm,
      };
      pthreadpool_compute_4d_tiled(
          threadpool,
          (pthreadpool_function_4d_tiled_t)compute_q8gemm_xzp,
          &q8gemm_xzp_context,
          groups,
          batch_size * input_size,
          input_size,
          group_output_channels,
          1,
          input_size,
          mr,
          nr);
      break;
    }
    case pytorch_qnnp_ukernel_type_gemm: {
      const size_t group_input_channels = convolution->group_input_channels;
      const size_t group_output_channels = convolution->group_output_channels;
      const uint32_t mr = pytorch_qnnp_params.q8conv.mr;
      const uint32_t nr = pytorch_qnnp_params.q8conv.nr;
      const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
      const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
      const size_t n_stride = (group_output_channels + (nr - 1)) & -nr;

      struct q8gemm_context q8gemm_context = {
          .k = convolution->group_input_channels,
          .k_stride = k_stride,
          .n = convolution->group_output_channels,
          .n_stride = n_stride,
          .a = input,
          .a_stride = input_pixel_stride,
          .packed_w = (uint8_t*)packed_weights,
          .c = output,
          .c_stride = output_pixel_stride,
          .quantization_params = conv_quantization_params,
          .ukernel = pytorch_qnnp_params.q8conv.gemm,
      };

      pthreadpool_compute_4d_tiled(
          threadpool,
          (pthreadpool_function_4d_tiled_t)compute_q8gemm,
          &q8gemm_context,
          groups,
          batch_size * output_size,
          output_size,
          group_output_channels,
          1,
          output_size,
          mr,
          nr);
      break;
    }
    case pytorch_qnnp_ukernel_type_conv: {
      const size_t group_input_channels = convolution->group_input_channels;
      const size_t group_output_channels = convolution->group_output_channels;
      const uint32_t mr = pytorch_qnnp_params.q8conv.mr;
      const uint32_t nr = pytorch_qnnp_params.q8conv.nr;
      const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
      const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
      const size_t n_stride = (group_output_channels + (nr - 1)) & -nr;
      const size_t m_stride = round_up(output_size, mr);

      struct q8conv_context q8conv_context = {
          .bs = batch_size,
          .ks = kernel_size,
          .kc = group_input_channels,
          .kc_stride = k_stride * kernel_size,
          .m = output_size,
          .m_stride = m_stride,
          .n = group_output_channels,
          .n_stride = n_stride,
          .indirect_a = (const uint8_t**)convolution->indirection_buffer,
          .packed_w = packed_weights,
          .c = output,
          .c_stride = output_pixel_stride,
          .quantization_params = conv_quantization_params,
          .ukernel = pytorch_qnnp_params.q8conv.conv,
      };

      pthreadpool_compute_4d_tiled(
          threadpool,
          (pthreadpool_function_4d_tiled_t)compute_q8conv,
          &q8conv_context,
          groups,
          batch_size,
          output_size,
          group_output_channels,
          1,
          1,
          mr,
          nr);
      break;
    }
    default: {
      pytorch_qnnp_log_error(
          "Invalid ukernel type. QNNPACK convolution run failed.");
      PYTORCH_QNNP_UNREACHABLE;
    }
  }
  return pytorch_qnnp_status_success;
}
} // namespace qnnpack