1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <assert.h>
10 #include <stddef.h>
11 #include <stdint.h>
12 #include <string.h>
13
14 #include <pytorch_qnnpack.h>
15 #include <qnnpack/common.h>
16 #include <qnnpack/log.h>
17 #include <qnnpack/math.h>
18 #include <qnnpack/operator.h>
19 #include <qnnpack/params.h>
20
21 #ifdef _MSC_VER
22 #include <malloc.h>
23 #endif
24
25 struct q8gemm_context {
26 size_t k;
27 size_t k_stride;
28 size_t n;
29 size_t n_stride;
30 const uint8_t* a;
31 size_t a_stride;
32 const uint8_t* packed_w;
33 uint8_t* c;
34 size_t c_stride;
35 union pytorch_qnnp_conv_quantization_params quantization_params;
36 const pytorch_q8gemm_ukernel_function ukernel;
37 };
38
39 static void compute_q8gemm(
40 const struct q8gemm_context context[RESTRICT_STATIC 1],
41 size_t group_index,
42 size_t pixel_index,
43 size_t mr_block_start,
44 size_t nr_block_start,
45 size_t group_range /* always 1 */,
46 size_t pixel_range,
47 size_t mr_block_size,
48 size_t nr_block_size) {
49 const size_t k = context->k;
50 const size_t k_stride = context->k_stride;
51 const size_t n = context->n;
52 const size_t n_stride = context->n_stride;
53 const uint8_t* restrict a = context->a;
54 const size_t a_stride = context->a_stride;
55 const void* restrict packed_w = context->packed_w;
56 uint8_t* restrict c = context->c;
57 const size_t c_stride = context->c_stride;
58
59 size_t output_channel_index = nr_block_start + group_index * n;
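// Each output channel contributes k_stride quantized-weight bytes plus one
// int32 bias to the packed weight buffer, hence the per-channel stride of
// (k_stride * sizeof(uint8_t) + sizeof(int32_t)) in the offset below.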
60 context->ukernel(
61 mr_block_size,
62 nr_block_size,
63 k,
64 a + (pixel_index + mr_block_start) * a_stride + group_index * k,
65 a_stride,
66 (const void*)((uintptr_t)packed_w + (nr_block_start + group_index * n_stride) * (k_stride * sizeof(uint8_t) + sizeof(int32_t))),
67 c + (pixel_index + mr_block_start) * c_stride + nr_block_start +
68 group_index * n,
69 c_stride,
70 output_channel_index,
71 &context->quantization_params);
72 }
73
74 // At the moment we opt to remove sparse kernels that
75 // don't require prepacking, as their performance was
76 // always worse.
77 #ifdef NO_PREPACK_SPARSE_KERNEL
78 struct q8gemm_sparse_dq_context {
79 const uint8_t* a;
80 size_t a_stride;
81 const uint32_t* kernel_col_indices;
82 const uint32_t* kernel_row_values;
83 const uint8_t* kernel_values;
84 const float* bias;
85 float* c; // can be float or uint8_t
86 size_t c_stride;
87 struct pytorch_qnnp_conv_dynamic_quantization_params quantization_params;
88 const pytorch_q8gemm_dq_sparse_ukernel_function ukernel;
89 };
90
91 static void compute_q8gemm_sparse_dq(
92 const struct q8gemm_sparse_dq_context context[RESTRICT_STATIC 1],
93 size_t group_index, /* ignored */
94 size_t pixel_index, /* ignored */
95 size_t mr_block_start,
96 size_t nr_block_start,
97 size_t group_range /* always 1 */,
98 size_t pixel_range,
99 size_t mr_block_size,
100 size_t nr_block_size) {
101 const uint8_t* restrict a = context->a;
102 const size_t a_stride = context->a_stride;
103 float* restrict c = (float*)context->c;
104 const size_t c_stride = context->c_stride;
105
106 size_t output_channel_index = nr_block_start;
107 context->ukernel(
108 mr_block_size,
109 nr_block_size,
110 a + mr_block_start * a_stride,
111 a_stride,
112 context->kernel_values,
113 context->kernel_row_values + nr_block_start,
114 context->kernel_col_indices,
115 context->bias + nr_block_start,
116 c + mr_block_start * c_stride + nr_block_start,
117 c_stride,
118 output_channel_index,
119 &context->quantization_params);
120 }
121 #endif
122
123 struct q8gemm_prepackA_sparse_dq_context {
124 size_t k;
125 const uint8_t* a;
126 size_t a_stride;
127 uint8_t* a_packed;
128 size_t a_packed_stride;
129 size_t log2_mr;
130 size_t log2_row_block_size;
131 union {
132 const uint32_t* kernel_col_indices_w32;
133 const uint16_t* kernel_col_indices_w16;
134 const uint8_t* kernel_col_indices_w8;
135 };
136 union {
137 const uint32_t* kernel_row_values_w32;
138 const uint16_t* kernel_row_values_w16;
139 const uint8_t* kernel_row_values_w8;
140 };
141 enum pytorch_qnnp_sparse_matrix_indices_dtype kernel_indices_dtype;
142 const uint8_t* kernel_values;
143 const float* bias;
144 float* c; // can be float or uint8_t
145 size_t c_stride;
146 struct pytorch_qnnp_conv_dynamic_quantization_params quantization_params;
147 union {
148 // Not const because assigned after context is initialized
149 pytorch_q8gemm_dq_sparse_packedA_w32_ukernel_function ukernel_w32;
150 pytorch_q8gemm_dq_sparse_packedA_w16_ukernel_function ukernel_w16;
151 pytorch_q8gemm_dq_sparse_packedA_w8_ukernel_function ukernel_w8;
152 };
153 const pytorch_q8gemm_sparse_packA_ukernel_function prepack_ukernel;
154 };
155
156 static void compute_q8gemm_prepack_a_sparse(
157 const struct q8gemm_prepackA_sparse_dq_context context[RESTRICT_STATIC 1],
158 size_t group_index, /* ignored */
159 size_t pixel_index, /* ignored */
160 size_t mr_block_start,
161 size_t nr_block_start,
162 size_t group_range /* always 1 */,
163 size_t pixel_range,
164 size_t mr_block_size,
165 size_t nr_block_size) {
166 const uint8_t* restrict a = context->a;
167 const size_t a_stride = context->a_stride;
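// The packed activation buffer stores A in tiles of mr rows, each
// a_packed_stride bytes long; shifting the row offset right by log2_mr
// converts it into a tile index.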
168 const size_t mr_packed_block_start =
169 ((mr_block_start >> context->log2_mr) * context->a_packed_stride);
170
171 context->prepack_ukernel(
172 mr_block_size,
173 context->k,
174 a + mr_block_start * a_stride,
175 a_stride,
176 context->a_packed + mr_packed_block_start);
177 }
178
179 static void compute_q8gemm_prepacked_sparse_dq(
180 const struct q8gemm_prepackA_sparse_dq_context context[RESTRICT_STATIC 1],
181 size_t group_index, /* ignored */
182 size_t pixel_index, /* ignored */
183 size_t mr_block_start,
184 size_t nr_block_start,
185 size_t group_range /* always 1 */,
186 size_t pixel_range,
187 size_t mr_block_size,
188 size_t nr_block_size) {
189 const size_t mr_packed_block_start =
190 ((mr_block_start >> context->log2_mr) * context->a_packed_stride);
191 const uint8_t* restrict a_packed = context->a_packed + mr_packed_block_start;
192 const size_t c_stride = context->c_stride;
193 float* restrict c =
194 ((float*)context->c) + mr_block_start * c_stride + nr_block_start;
195 const size_t kernel_row_values_shift =
196 nr_block_start >> context->log2_row_block_size;
197 const float* bias = context->bias + nr_block_start;
198 const size_t output_channel_index = nr_block_start;
199
200 switch (context->kernel_indices_dtype) {
201 case pytorch_qnnp_sparse_matrix_indices_dtype_uint32_t:
202 context->ukernel_w32(
203 mr_block_size,
204 nr_block_size,
205 a_packed,
206 context->kernel_values,
207 context->kernel_row_values_w32 + kernel_row_values_shift,
208 context->kernel_col_indices_w32,
209 bias,
210 c,
211 c_stride,
212 output_channel_index,
213 &context->quantization_params);
214 break;
215 case pytorch_qnnp_sparse_matrix_indices_dtype_uint16_t:
216 context->ukernel_w16(
217 mr_block_size,
218 nr_block_size,
219 a_packed,
220 context->kernel_values,
221 context->kernel_row_values_w16 + kernel_row_values_shift,
222 context->kernel_col_indices_w16,
223 bias,
224 c,
225 c_stride,
226 output_channel_index,
227 &context->quantization_params);
228 break;
229 case pytorch_qnnp_sparse_matrix_indices_dtype_uint8_t:
230 context->ukernel_w8(
231 mr_block_size,
232 nr_block_size,
233 a_packed,
234 context->kernel_values,
235 context->kernel_row_values_w8 + kernel_row_values_shift,
236 context->kernel_col_indices_w8,
237 bias,
238 c,
239 c_stride,
240 output_channel_index,
241 &context->quantization_params);
242 break;
243 case pytorch_qnnp_sparse_matrix_indices_dtype_invalid:
244 // This function cannot return an error code without substantially
245 // changing the internal API. A check for invalid index type should
246 // already exist in the calling function. If the code reaches here, then
247 // please add / restore the index check in the calling function.
248 pytorch_qnnp_log_error(
249 "Invalid indices dtype specified for "
250 "operator-run compute_q8gemm_prepacked_sparse_dq");
251 assert(false);
252 }
253 }
254
255 struct q8sum_rows_context {
256 const uint8_t* a;
257 size_t groups;
258 size_t m;
259 size_t k;
260 size_t a_stride;
261 const int32_t multiplier;
262 int32_t* a_sum;
263 size_t a_sum_stride;
264 const pytorch_q8sum_rows_ukernel_function ukernel;
265 };
266
267 static void compute_sum_rows(
268 const struct q8sum_rows_context context[RESTRICT_STATIC 1],
269 size_t group_index,
270 size_t batch_index,
271 size_t block_start,
272 size_t group_range /* always 1 */,
273 size_t batch_range /* always 1 */,
274 size_t block_size) {
275 const uint8_t* a = context->a;
276 const size_t groups = context->groups;
277 const size_t m = context->m;
278 const size_t k = context->k;
279 const size_t a_stride = context->a_stride;
280 const int32_t multiplier = context->multiplier;
281 int32_t* a_sum = context->a_sum;
282 const size_t a_sum_stride = context->a_sum_stride;
283
284 context->ukernel(
285 a + batch_index * m * a_stride + group_index * k + block_start * a_stride,
286 min(block_size, m - block_start),
287 k,
288 a_stride,
289 multiplier,
290 a_sum + batch_index * groups * a_sum_stride + group_index * a_sum_stride +
291 block_start);
292 }
293
294 struct q8gemm_xzp_context {
295 size_t k;
296 size_t k_stride;
297 size_t n;
298 size_t n_stride;
299 const uint8_t* a;
300 size_t a_stride;
301 const void* packed_w;
302 uint8_t* c;
303 size_t c_stride;
304 const int32_t* a_sum;
305 size_t groups;
306 size_t batch_size;
307 size_t a_sum_stride;
308 union pytorch_qnnp_q31_requantization_params requantization_params;
309 const pytorch_q8gemm_xzp_ukernel_function ukernel;
310 };
311
312 static void compute_q8gemm_xzp(
313 const struct q8gemm_xzp_context context[RESTRICT_STATIC 1],
314 size_t group_index,
315 size_t pixel_index,
316 size_t mr_block_start,
317 size_t nr_block_start,
318 size_t group_range /* always 1 */,
319 size_t pixel_range,
320 size_t mr_block_size,
321 size_t nr_block_size) {
322 const size_t k = context->k;
323 const size_t k_stride = context->k_stride;
324 const size_t n = context->n;
325 const size_t n_stride = context->n_stride;
326 const uint8_t* restrict a = context->a;
327 const size_t a_stride = context->a_stride;
328 const void* restrict packed_w = context->packed_w;
329 uint8_t* restrict c = context->c;
330 const size_t c_stride = context->c_stride;
331 const int32_t* a_sum = context->a_sum;
332 const size_t groups = context->groups;
333 const size_t a_sum_stride = context->a_sum_stride;
334
335 context->ukernel(
336 mr_block_size,
337 nr_block_size,
338 k,
339 a + (pixel_index + mr_block_start) * a_stride + group_index * k,
340 a_stride,
341 a_sum + pixel_index * groups + group_index * a_sum_stride +
342 mr_block_start,
343 (const void*)((uintptr_t)packed_w + (nr_block_start + group_index * n_stride) * (k_stride * sizeof(uint8_t) + sizeof(int32_t))),
344 c + (pixel_index + mr_block_start) * c_stride + nr_block_start +
345 group_index * n,
346 c_stride,
347 &context->requantization_params);
348 }
349
350 struct q8conv_context {
351 size_t bs;
352 size_t ks;
353 size_t kc;
354 size_t kc_stride;
355 size_t m;
356 size_t m_stride;
357 size_t n;
358 size_t n_stride;
359 const uint8_t** indirect_a;
360 const void* packed_w;
361 uint8_t* c;
362 size_t c_stride;
363 union pytorch_qnnp_conv_quantization_params quantization_params;
364 const pytorch_q8conv_ukernel_function ukernel;
365 };
366
367 static void compute_q8conv(
368 const struct q8conv_context context[RESTRICT_STATIC 1],
369 size_t group_index,
370 size_t image_index,
371 size_t mr_block_start,
372 size_t nr_block_start,
373 size_t group_range /* always 1 */,
374 size_t image_range /* always 1 */,
375 size_t mr_block_size,
376 size_t nr_block_size) {
377 const size_t bs = context->bs;
378 const size_t ks = context->ks;
379 const size_t kc = context->kc;
380 const size_t kc_stride = context->kc_stride;
381 const size_t m = context->m;
382 const size_t m_stride = context->m_stride;
383 const size_t n = context->n;
384 const size_t n_stride = context->n_stride;
385 const uint8_t** restrict indirect_a = context->indirect_a;
386 const void* restrict packed_w = context->packed_w;
387 uint8_t* restrict c = context->c;
388 const size_t c_stride = context->c_stride;
389
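// indirect_a holds ks input pointers per output pixel; pixels are laid out
// per (image, group) with m_stride entries each, which the indexing below
// follows.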
390 size_t output_channel_index = nr_block_start + group_index * n;
391 context->ukernel(
392 mr_block_size,
393 nr_block_size,
394 kc,
395 ks,
396 indirect_a +
397 (mr_block_start + (image_index + group_index * bs) * m_stride) * ks,
398 (const void*)((uintptr_t)packed_w + (nr_block_start + group_index * n_stride) * (kc_stride * sizeof(uint8_t) + sizeof(int32_t))),
399 c + (mr_block_start + image_index * m) * c_stride + group_index * n +
400 nr_block_start,
401 c_stride,
402 output_channel_index,
403 &context->quantization_params);
404 }
405
406 struct q8dwconv2d_context {
407 size_t groups;
408 size_t group_stride;
409 const uint8_t** indirection_buffer;
410 size_t indirection_buffer_row_stride;
411 size_t indirection_buffer_col_stride;
412 const void* packed_weights;
413 uint8_t* output;
414 size_t output_height;
415 size_t output_width;
416 size_t output_row_stride;
417 size_t output_col_increment;
418 union pytorch_qnnp_conv_quantization_params quantization_params;
419 union {
420 const pytorch_q8dwconv2d_up_ukernel_function unipass_ukernel;
421 const pytorch_q8dwconv2d_mp_ukernel_function multipass_ukernel;
422 };
423 };
424
425 struct q8dwconv3d_context {
426 size_t groups;
427 size_t group_stride;
428 const uint8_t** indirection_buffer;
429 size_t indirection_buffer_slice_stride;
430 size_t indirection_buffer_row_stride;
431 size_t indirection_buffer_col_stride;
432 const void* packed_weights;
433 uint8_t* output;
434 size_t output_depth;
435 size_t output_height;
436 size_t output_width;
437 size_t output_slice_stride;
438 union pytorch_qnnp_conv_quantization_params quantization_params;
439 const pytorch_q8dwconv3d_mp_ukernel_function multipass_ukernel;
440 };
441
442 static void compute_dwconv2d_unipass(
443 const struct q8dwconv2d_context context[RESTRICT_STATIC 1],
444 size_t image,
445 size_t output_y) {
446 const size_t output_height = context->output_height;
447
448 context->unipass_ukernel(
449 context->groups,
450 context->output_width,
451 context->indirection_buffer +
452 (image * output_height + output_y) *
453 context->indirection_buffer_row_stride,
454 context->packed_weights,
455 context->output +
456 (image * output_height + output_y) * context->output_row_stride,
457 context->indirection_buffer_col_stride,
458 context->output_col_increment,
459 &context->quantization_params);
460 }
461
462 static void compute_dwconv2d_multiipass(
463 const struct q8dwconv2d_context context[RESTRICT_STATIC 1],
464 size_t image,
465 size_t output_y) {
466 const size_t output_height = context->output_height;
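// Scratch accumulator for the multi-pass depthwise kernel: a VLA on most
// compilers, but _malloca/_freea on MSVC, which has no C99 VLA support.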
467 PYTORCH_QNNP_ALIGN(16)
468 #ifdef _MSC_VER
469 int32_t* multipass_acc = _malloca(sizeof(int32_t) * context->group_stride);
470 #else
471 int32_t multipass_acc[context->group_stride];
472 #endif
473
474 context->multipass_ukernel(
475 context->groups,
476 context->output_width,
477 context->indirection_buffer +
478 (image * output_height + output_y) *
479 context->indirection_buffer_row_stride,
480 context->packed_weights,
481 multipass_acc,
482 context->output +
483 (image * output_height + output_y) * context->output_row_stride,
484 context->indirection_buffer_col_stride,
485 context->output_col_increment,
486 &context->quantization_params);
487
488 #ifdef _MSC_VER
489 _freea(multipass_acc);
490 #endif
491 }
492
493 static void compute_dwconv3d_multiipass(
494 const struct q8dwconv3d_context context[1],
495 size_t image,
496 size_t output_z) {
497 const size_t output_depth = context->output_depth;
498 PYTORCH_QNNP_ALIGN(16)
499 #ifdef _MSC_VER
500 int32_t* multipass_acc =
501 (int32_t*)_malloca(sizeof(int32_t) * context->group_stride);
502 #else
503 int32_t multipass_acc[context->group_stride];
504 #endif
505
506 context->multipass_ukernel(
507 context->groups,
508 context->output_height,
509 context->output_width,
510 context->indirection_buffer +
511 (image * output_depth + output_z) *
512 context->indirection_buffer_slice_stride,
513 context->packed_weights,
514 multipass_acc,
515 context->output +
516 (image * output_depth + output_z) * context->output_slice_stride,
517 context->indirection_buffer_row_stride,
518 context->indirection_buffer_col_stride,
519 0,
520 &context->quantization_params);
521
522 #ifdef _MSC_VER
523 _freea(multipass_acc);
524 #endif
525 }
526
527 struct max_pooling_context {
528 const void** indirect_input;
529 size_t indirect_input_batch_stride;
530 size_t indirect_input_height_stride;
531 void* output;
532 size_t output_batch_stride;
533 size_t output_height_stride;
534 size_t output_width;
535 size_t pooling_size;
536 size_t channels;
537 size_t input_increment;
538 size_t output_increment;
539 union pytorch_qnnp_u8_clamping_params params;
540 pytorch_u8maxpool_ukernel_function ukernel;
541 };
542
543 static void compute_max_pooling(
544 const struct max_pooling_context context[RESTRICT_STATIC 1],
545 size_t batch_index,
546 size_t output_y) {
547 const void** indirect_input =
548 (const void**) ((uintptr_t) context->indirect_input +
549 batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
550 void* output =
551 (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
552
553 context->ukernel(
554 context->output_width,
555 context->pooling_size,
556 context->channels,
557 (const uint8_t**)indirect_input,
558 output,
559 context->input_increment,
560 context->output_increment,
561 &context->params);
562 }
563
564 struct average_pooling_context {
565 const void** indirect_input;
566 size_t indirect_input_batch_stride;
567 size_t indirect_input_height_stride;
568 void* output;
569 size_t output_batch_stride;
570 size_t output_height_stride;
571 size_t output_width;
572 size_t pooling_size;
573 size_t channels;
574 size_t packed_channels;
575 const void* zero;
576 size_t input_increment;
577 size_t output_increment;
578 union pytorch_qnnp_avgpool_quantization_params quantization_params;
579 union {
580 pytorch_q8avgpool_up_ukernel_function unipass_ukernel;
581 pytorch_q8avgpool_mp_ukernel_function multipass_ukernel;
582 };
583 };
584
585 static void compute_average_pooling_unipass(
586 const struct average_pooling_context context[RESTRICT_STATIC 1],
587 size_t batch_index,
588 size_t output_y) {
589 const void** indirect_input =
590 (const void**) ((uintptr_t) context->indirect_input +
591 batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
592 void* output =
593 (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
594
595 context->unipass_ukernel(
596 context->output_width,
597 context->pooling_size,
598 context->channels,
599 (const uint8_t**)indirect_input,
600 context->zero,
601 output,
602 context->input_increment,
603 context->output_increment,
604 &context->quantization_params);
605 }
606
607 static void compute_average_pooling_multipass(
608 const struct average_pooling_context context[RESTRICT_STATIC 1],
609 size_t batch_index,
610 size_t output_y) {
611 const void** indirect_input =
612 (const void**) ((uintptr_t) context->indirect_input +
613 batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
614 void* output =
615 (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
616 PYTORCH_QNNP_ALIGN(16)
617 #ifdef _MSC_VER
618 int32_t* multipass_buffer =
619 _malloca(sizeof(int32_t) * context->packed_channels);
620 #else
621 int32_t multipass_buffer[context->packed_channels];
622 #endif
623
624 context->multipass_ukernel(
625 context->output_width,
626 context->pooling_size,
627 context->channels,
628 (const uint8_t**)indirect_input,
629 context->zero,
630 multipass_buffer,
631 output,
632 context->input_increment,
633 context->output_increment,
634 &context->quantization_params);
635
636 #ifdef _MSC_VER
637 _freea(multipass_buffer);
638 #endif
639 }
640
641 struct global_average_pooling_context {
642 const void* input;
643 const void* zero;
644 size_t input_pixel_stride;
645 size_t input_batch_stride;
646 size_t input_elements;
647 size_t channels;
648 size_t packed_channels;
649 void* output;
650 size_t output_batch_stride;
651 union pytorch_qnnp_avgpool_quantization_params quantization_params;
652 union {
653 pytorch_q8gavgpool_up_ukernel_function unipass_ukernel;
654 pytorch_q8gavgpool_mp_ukernel_function multipass_ukernel;
655 };
656 };
657
658 static void compute_global_average_pooling_unipass(
659 const struct global_average_pooling_context context[RESTRICT_STATIC 1],
660 size_t batch_index) {
661 const void* input =
662 (const void*)((uintptr_t)context->input + batch_index * context->input_batch_stride);
663 void* output =
664 (void*)((uintptr_t)context->output + batch_index * context->output_batch_stride);
665
666 context->unipass_ukernel(
667 context->input_elements,
668 context->channels,
669 input,
670 context->input_pixel_stride,
671 context->zero,
672 output,
673 &context->quantization_params);
674 }
675
676 static void compute_global_average_pooling_multipass(
677 const struct global_average_pooling_context context[RESTRICT_STATIC 1],
678 size_t batch_index) {
679 const void* input =
680 (const void*)((uintptr_t)context->input + batch_index * context->input_batch_stride);
681 void* output =
682 (void*)((uintptr_t)context->output + batch_index * context->output_batch_stride);
683 PYTORCH_QNNP_ALIGN(16)
684 #ifdef _MSC_VER
685 int32_t* multipass_buffer =
686 _malloca(sizeof(int32_t) * context->packed_channels);
687 #else
688 int32_t multipass_buffer[context->packed_channels];
689 #endif
690
691 context->multipass_ukernel(
692 context->input_elements,
693 context->channels,
694 input,
695 context->input_pixel_stride,
696 context->zero,
697 multipass_buffer,
698 output,
699 &context->quantization_params);
700
701 #ifdef _MSC_VER
702 _freea(multipass_buffer);
703 #endif
704 }
705
706 struct q8add_strided_context {
707 size_t n;
708 const uint8_t* a;
709 size_t a_stride;
710 const uint8_t* b;
711 size_t b_stride;
712 const uint8_t* y;
713 size_t y_stride;
714 union pytorch_qnnp_add_quantization_params quantization_params;
715 pytorch_q8vadd_ukernel_function ukernel;
716 };
717
718 static void compute_q8add_strided(
719 const struct q8add_strided_context context[RESTRICT_STATIC 1],
720 size_t batch_offset,
721 size_t batch_range /* always 1 */) {
722 assert(batch_range == 1);
723
724 const size_t n = context->n;
725 const size_t a_stride = context->a_stride;
726 const size_t b_stride = context->b_stride;
727 const size_t y_stride = context->y_stride;
728 const void* a =
729 (const void*)((uintptr_t)context->a + a_stride * batch_offset);
730 const void* b =
731 (const void*)((uintptr_t)context->b + b_stride * batch_offset);
732 void* y = (void*)((uintptr_t)context->y + y_stride * batch_offset);
733
734 context->ukernel(n, a, b, y, &context->quantization_params);
735 }
736
737 struct q8add_contiguous_context {
738 const uint8_t* a;
739 const uint8_t* b;
740 uint8_t* y;
741 union pytorch_qnnp_add_quantization_params quantization_params;
742 pytorch_q8vadd_ukernel_function ukernel;
743 };
744
745 static void compute_q8add_contiguous(
746 const struct q8add_contiguous_context context[RESTRICT_STATIC 1],
747 size_t offset,
748 size_t size) {
749 const void* a = (const void*)((uintptr_t)context->a + offset);
750 const void* b = (const void*)((uintptr_t)context->b + offset);
751 void* y = (void*)((uintptr_t)context->y + offset);
752 context->ukernel(size, a, b, y, &context->quantization_params);
753 }
754
755 struct channel_shuffle_context {
756 const void* x;
757 size_t x_stride;
758 void* y;
759 size_t y_stride;
760 size_t n;
761 size_t m;
762 union {
763 pytorch_xzipc_ukernel_function fixed_ukernel;
764 pytorch_xzipv_ukernel_function variable_ukernel;
765 };
766 };
767
768 static void compute_channel_shuffle_fixed(
769 const struct channel_shuffle_context context[RESTRICT_STATIC 1],
770 size_t index) {
771 const void* x =
772 (const void*)((uintptr_t)context->x + index * context->x_stride);
773 void* y = (void*)((uintptr_t)context->y + index * context->y_stride);
774
775 context->fixed_ukernel(context->n, x, y);
776 }
777
778 static void compute_channel_shuffle_variable(
779 const struct channel_shuffle_context context[RESTRICT_STATIC 1],
780 size_t index) {
781 const void* x =
782 (const void*)((uintptr_t)context->x + index * context->x_stride);
783 void* y = (void*)((uintptr_t)context->y + index * context->y_stride);
784
785 context->variable_ukernel(context->n, context->m, x, y);
786 }
787
788 struct lut_strided_context {
789 size_t n;
790 const void* x;
791 size_t x_stride;
792 const void* t;
793 void* y;
794 size_t y_stride;
795 pytorch_x8lut_ukernel_function ukernel;
796 };
797
798 static void compute_lut_strided(
799 const struct lut_strided_context context[RESTRICT_STATIC 1],
800 size_t batch_index) {
801 const void* x =
802 (const void*)((uintptr_t)context->x + context->x_stride * batch_index);
803 void* y = (void*)((uintptr_t)context->y + context->y_stride * batch_index);
804
805 context->ukernel(context->n, x, context->t, y);
806 }
807
808 struct lut_contiguous_context {
809 const void* x;
810 size_t x_stride;
811 const void* t;
812 void* y;
813 size_t y_stride;
814 pytorch_x8lut_ukernel_function ukernel;
815 };
816
817 static void compute_lut_contiguous(
818 const struct lut_contiguous_context context[RESTRICT_STATIC 1],
819 size_t offset,
820 size_t size) {
821 const void* x = (const void*)((uintptr_t)context->x + offset);
822 void* y = (void*)((uintptr_t)context->y + offset);
823
824 context->ukernel(size, x, context->t, y);
825 }
826
827 struct clamp_strided_context {
828 size_t n;
829 const void* x;
830 size_t x_stride;
831 void* y;
832 size_t y_stride;
833 pytorch_u8clamp_ukernel_function ukernel;
834 union pytorch_qnnp_u8_clamping_params params;
835 };
836
837 static void compute_clamp_strided(
838 const struct clamp_strided_context context[RESTRICT_STATIC 1],
839 size_t batch_index) {
840 const void* x =
841 (const void*)((uintptr_t)context->x + context->x_stride * batch_index);
842 void* y = (void*)((uintptr_t)context->y + context->y_stride * batch_index);
843 context->ukernel(context->n, x, y, &context->params);
844 }
845
846 struct clamp_contiguous_context {
847 const void* x;
848 size_t x_stride;
849 void* y;
850 size_t y_stride;
851 pytorch_u8clamp_ukernel_function ukernel;
852 union pytorch_qnnp_u8_clamping_params params;
853 };
854
855 static void compute_clamp_contiguous(
856 const struct clamp_contiguous_context context[RESTRICT_STATIC 1],
857 size_t offset,
858 size_t size) {
859 const void* x = (const void*)((uintptr_t)context->x + offset);
860 void* y = (void*)((uintptr_t)context->y + offset);
861 context->ukernel(size, x, y, &context->params);
862 }
863
864 struct u8softargmax_context {
865 size_t n;
866 const uint8_t* x;
867 size_t x_stride;
868 const uint32_t* t;
869 uint8_t* y;
870 size_t y_stride;
871 pytorch_u8rmax_ukernel_function rmax_ukernel;
872 pytorch_u8lut32norm_ukernel_function lut_norm_ukernel;
873 };
874
875 static void compute_u8softargmax(
876 const struct u8softargmax_context context[RESTRICT_STATIC 1],
877 size_t batch_index) {
878 const uint8_t* x =
879 (const uint8_t*)((uintptr_t)context->x + context->x_stride * batch_index);
880 uint8_t* y =
881 (uint8_t*)((uintptr_t)context->y + context->y_stride * batch_index);
882 const size_t n = context->n;
883
884 const uint8_t x_max = context->rmax_ukernel(n, x);
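// Since x_max <= 255, x_max ^ 255 equals 255 - x_max; offsetting the lookup
// table by this amount maps the largest observed input to the table's last
// entry before normalization.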
885 const size_t adjustment = x_max ^ 255;
886 const uint32_t* t = (const uint32_t*)context->t + adjustment;
887 context->lut_norm_ukernel(n, x, t, y);
888 }
889
890 enum pytorch_qnnp_status pytorch_qnnp_run_operator(
891 pytorch_qnnp_operator_t op,
892 pthreadpool_t threadpool) {
893 // For any ukernel type, there is no work to do if the batch size is 0.
894 if (op->batch_size == 0) {
895 return pytorch_qnnp_status_success;
896 }
897
898 switch (op->ukernel_type) {
899 case pytorch_qnnp_ukernel_type_dwconv: {
900 const size_t batch_size = op->batch_size;
901 const size_t groups = op->groups;
902 const size_t kernel_depth = op->kernel_depth;
903 const size_t kernel_height = op->kernel_height;
904 const size_t kernel_width = op->kernel_width;
905 const size_t kernel_size = kernel_depth * kernel_height * kernel_width;
906 const size_t output_depth = op->output_depth;
907 const size_t output_height = op->output_height;
908 const size_t output_width = op->output_width;
909
910 const size_t step_height = op->step_height;
911 const size_t step_width = op->step_width;
912
913 switch (kernel_size) {
914 case 9: {
915 struct q8dwconv2d_context context = {
916 .groups = groups,
917 .indirection_buffer = (const uint8_t**)op->indirection_buffer,
918 .indirection_buffer_row_stride = step_height,
919 .indirection_buffer_col_stride =
920 kernel_height * step_width * sizeof(void*),
921 .packed_weights = op->packed_weights,
922 .output = op->output,
923 .output_height = output_height,
924 .output_width = output_width,
925 .output_row_stride = output_width * op->output_pixel_stride,
926 .output_col_increment =
927 (op->output_pixel_stride - groups) * sizeof(uint8_t),
928 .quantization_params = op->conv_quantization_params,
929 .unipass_ukernel = op->per_channel
930 ? pytorch_qnnp_params.q8dw9.updw_per_channel
931 : pytorch_qnnp_params.q8dw9.updw,
932 };
933 pthreadpool_compute_2d(
934 threadpool,
935 (pthreadpool_function_2d_t)compute_dwconv2d_unipass,
936 &context,
937 batch_size,
938 output_height);
939 break;
940 }
941 case 25: {
942 struct q8dwconv2d_context context = {
943 .groups = groups,
944 .group_stride = op->group_stride,
945 .indirection_buffer = (const uint8_t**)op->indirection_buffer,
946 .indirection_buffer_row_stride = step_height,
947 .indirection_buffer_col_stride =
948 kernel_height * step_width * sizeof(void*),
949 .packed_weights = op->packed_weights,
950 .output = op->output,
951 .output_height = output_height,
952 .output_width = output_width,
953 .output_row_stride = output_width * op->output_pixel_stride,
954 .output_col_increment =
955 (op->output_pixel_stride - groups) * sizeof(uint8_t),
956 .quantization_params = op->conv_quantization_params,
957 .multipass_ukernel = op->per_channel
958 ? pytorch_qnnp_params.q8dw25.mpdw_per_channel
959 : pytorch_qnnp_params.q8dw25.mpdw,
960 };
961 pthreadpool_compute_2d(
962 threadpool,
963 (pthreadpool_function_2d_t)compute_dwconv2d_multiipass,
964 &context,
965 batch_size,
966 output_height);
967 break;
968 }
969 case 27: {
970 struct q8dwconv3d_context context = {
971 .groups = groups,
972 .group_stride = op->group_stride,
973 .indirection_buffer = (const uint8_t**)op->indirection_buffer,
974 .indirection_buffer_slice_stride = step_height * output_height,
975 .indirection_buffer_row_stride = step_height * sizeof(void*),
976 .indirection_buffer_col_stride =
977 kernel_height * kernel_depth * step_width * sizeof(void*),
978 .packed_weights = op->packed_weights,
979 .output = op->output,
980 .output_depth = output_depth,
981 .output_height = output_height,
982 .output_width = output_width,
983 .output_slice_stride =
984 output_height * output_width * op->output_pixel_stride,
985 .quantization_params = op->conv_quantization_params,
986 .multipass_ukernel = pytorch_qnnp_params.q8dw27.mpdw,
987 };
988 pthreadpool_compute_2d(
989 threadpool,
990 (pthreadpool_function_2d_t)compute_dwconv3d_multiipass,
991 &context,
992 batch_size,
993 output_depth);
994 break;
995 }
996 default:
997 PYTORCH_QNNP_UNREACHABLE;
998 }
999 break;
1000 }
1001 case pytorch_qnnp_ukernel_type_xzp_gemm: {
1002 const size_t batch_size = op->batch_size;
1003 const size_t groups = op->groups;
1004 const size_t group_input_channels = op->group_input_channels;
1005 const size_t group_output_channels = op->group_output_channels;
1006 const uint32_t mr = pytorch_qnnp_params.q8conv_xzp.mr;
1007 const uint32_t nr = pytorch_qnnp_params.q8conv_xzp.nr;
1008 const uint32_t kr = pytorch_qnnp_params.q8conv_xzp.kr;
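// (x + (kr - 1)) & -kr rounds x up to the next multiple of kr (a power of
// two); the same trick is used for n_stride with nr below.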
1009 const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
1010 const size_t n_stride = (group_output_channels + (nr - 1)) & -nr;
1011
1012 /* compute input row sum */
1013 const size_t input_size = op->input_height * op->input_width;
1014 int32_t* a_sum = (int32_t*)op->a_sum;
1015
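// Precompute per-row sums of the input scaled by -kernel_zero_point; the
// XZP GEMM ukernel uses these sums to apply its zero-point correction.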
1016 struct q8sum_rows_context context = {
1017 .a = op->input,
1018 .groups = groups,
1019 .m = input_size,
1020 .k = group_input_channels,
1021 .a_stride = op->input_pixel_stride,
1022 .multiplier = (int32_t)-op->kernel_zero_point,
1023 .a_sum = a_sum,
1024 .a_sum_stride = input_size,
1025 .ukernel = pytorch_qnnp_params.q8sum_rows.sum_rows,
1026 };
1027 pthreadpool_compute_3d_tiled(
1028 threadpool,
1029 (pthreadpool_function_3d_tiled_t)compute_sum_rows,
1030 &context,
1031 groups,
1032 batch_size,
1033 input_size,
1034 1,
1035 1,
1036 pytorch_qnnp_params.q8sum_rows.m);
1037
1038 struct q8gemm_xzp_context q8gemm_xzp_context = {
1039 .k = group_input_channels,
1040 .k_stride = k_stride,
1041 .n = group_output_channels,
1042 .n_stride = n_stride,
1043 .a = op->input,
1044 .a_stride = op->input_pixel_stride,
1045 .packed_w = op->packed_weights,
1046 .c = op->output,
1047 .c_stride = op->output_pixel_stride,
1048 .a_sum = a_sum,
1049 .groups = op->groups,
1050 .batch_size = batch_size,
1051 .a_sum_stride = input_size,
1052 .requantization_params = op->requantization_params,
1053 .ukernel = pytorch_qnnp_params.q8conv_xzp.gemm,
1054 };
1055 pthreadpool_compute_4d_tiled(
1056 threadpool,
1057 (pthreadpool_function_4d_tiled_t)compute_q8gemm_xzp,
1058 &q8gemm_xzp_context,
1059 groups,
1060 batch_size * input_size,
1061 input_size,
1062 group_output_channels,
1063 1,
1064 input_size,
1065 mr,
1066 nr);
1067 break;
1068 }
1069 case pytorch_qnnp_ukernel_type_gemm: {
1070 const size_t batch_size = op->batch_size;
1071 const size_t groups = op->groups;
1072 const size_t group_input_channels = op->group_input_channels;
1073 const size_t group_output_channels = op->group_output_channels;
1074 const uint32_t mr = pytorch_qnnp_params.q8conv.mr;
1075 const uint32_t nr = pytorch_qnnp_params.q8conv.nr;
1076 const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
1077 const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
1078 const size_t n_stride = (group_output_channels + (nr - 1)) & -nr;
1079 const size_t output_depth = op->output_depth;
1080 const size_t output_size = (output_depth != 0 ? output_depth : 1) *
1081 op->output_height * op->output_width;
1082
1083 struct q8gemm_context q8gemm_context = {
1084 .k = group_input_channels,
1085 .k_stride = k_stride,
1086 .n = group_output_channels,
1087 .n_stride = n_stride,
1088 .a = op->input,
1089 .a_stride = op->input_pixel_stride,
1090 .packed_w = op->packed_weights,
1091 .c = op->output,
1092 .c_stride = op->output_pixel_stride,
1093 .quantization_params = op->conv_quantization_params,
1094 .ukernel = pytorch_qnnp_params.q8conv.gemm,
1095 };
1096
1097 pthreadpool_compute_4d_tiled(
1098 threadpool,
1099 (pthreadpool_function_4d_tiled_t)compute_q8gemm,
1100 &q8gemm_context,
1101 groups,
1102 batch_size * output_size,
1103 output_size,
1104 group_output_channels,
1105 1,
1106 output_size,
1107 mr,
1108 nr);
1109 break;
1110 }
1111 #ifdef NO_PREPACK_SPARSE_KERNEL
1112 case pytorch_qnnp_ukernel_type_gemm_sparse_dq: {
1113 const size_t batch_size = op->batch_size;
1114 const size_t groups = op->groups;
1115 const size_t group_output_channels = op->group_output_channels;
1116 const uint32_t mr = pytorch_qnnp_params.q8gemm_sparse_c1x4.mr;
1117 const uint32_t nr = pytorch_qnnp_params.q8gemm_sparse_c1x4.nr;
1118
1119 const size_t output_size = op->output_height * op->output_width;
1120 struct q8gemm_sparse_dq_context q8gemm_sparse_dq_context = {
1121 .a = op->input,
1122 .a_stride = op->input_pixel_stride,
1123 .kernel_col_indices = op->sparse_matrix.col_indices,
1124 .kernel_row_values = op->sparse_matrix.row_values,
1125 .kernel_values = op->sparse_matrix.values,
1126 .bias = (const float*)op->bias,
1127 .c = (float*)op->output,
1128 .c_stride = op->output_pixel_stride,
1129 .quantization_params = op->dynamic_conv_quantization_params,
1130 .ukernel = pytorch_qnnp_params.q8gemm_sparse_c1x4.gemm_dq,
1131 };
1132
1133 pthreadpool_compute_4d_tiled(
1134 threadpool,
1135 (pthreadpool_function_4d_tiled_t)compute_q8gemm_sparse_dq,
1136 &q8gemm_sparse_dq_context,
1137 groups,
1138 batch_size * output_size,
1139 output_size,
1140 group_output_channels,
1141 1,
1142 output_size,
1143 mr,
1144 nr);
1145 break;
1146 }
1147 #endif
1148 case pytorch_qnnp_ukernel_type_gemm_prepackA_sparse_dq: {
1149 const size_t batch_size = op->batch_size;
1150 const size_t groups = op->groups;
1151 const size_t group_input_channels = op->group_input_channels;
1152 const size_t group_output_channels = op->group_output_channels;
1153 uint32_t mr, log2_mr, nr, kr, log2_row_block_size;
1154 pytorch_q8gemm_sparse_packA_ukernel_function prepack_kernel;
1155 struct pytorch_q8gemm_sparse_parameters* pytorch_q8gemm_sparse_params =
1156 NULL; // used to assign ukernel
1157 if (op->sparse_matrix.row_block_size == 1 &&
1158 op->sparse_matrix.col_block_size == 4) {
1159 mr = pytorch_qnnp_params.q8gemm_sparse_c1x4.mr;
1160 log2_mr = pytorch_qnnp_params.q8gemm_sparse_c1x4.log2_mr;
1161 log2_row_block_size = 0;
1162 nr = pytorch_qnnp_params.q8gemm_sparse_c1x4.nr;
1163 kr = pytorch_qnnp_params.q8gemm_sparse_c1x4.kr;
1164 prepack_kernel = pytorch_qnnp_params.q8gemm_sparse_c1x4.packA;
1165 pytorch_q8gemm_sparse_params = &pytorch_qnnp_params.q8gemm_sparse_c1x4;
1166 } else if (op->sparse_matrix.row_block_size == 8 &&
1167 op->sparse_matrix.col_block_size == 1) {
1168 mr = pytorch_qnnp_params.q8gemm_sparse_c8x1.mr;
1169 log2_mr = pytorch_qnnp_params.q8gemm_sparse_c8x1.log2_mr;
1170 log2_row_block_size = 3;
1171 nr = pytorch_qnnp_params.q8gemm_sparse_c8x1.nr;
1172 kr = pytorch_qnnp_params.q8gemm_sparse_c8x1.kr;
1173 prepack_kernel = pytorch_qnnp_params.q8gemm_sparse_c8x1.packA;
1174 pytorch_q8gemm_sparse_params = &pytorch_qnnp_params.q8gemm_sparse_c8x1;
1175 } else {
1176 return pytorch_qnnp_status_invalid_parameter;
1177 }
1178
1179 const size_t output_size = op->output_height * op->output_width;
1180 const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
1181 const size_t m_stride = (output_size + (mr - 1)) & -mr;
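// The packed-activation scratch buffer holds k_stride bytes per row for
// output_size rows rounded up to a multiple of mr (m_stride).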
1182 op->prepacked_a =
1183 (uint8_t*)realloc((void*)op->prepacked_a, k_stride * m_stride);
1184 if (op->prepacked_a == NULL) {
1185 pytorch_qnnp_log_error(
1186 "failed to allocate %zu bytes for packed activation buffer",
1187 (k_stride * m_stride));
1188 return pytorch_qnnp_status_out_of_memory;
1189 }
1190
1191 struct q8gemm_prepackA_sparse_dq_context
1192 q8gemm_prepack_sparse_dq_context = {
1193 .k = group_input_channels,
1194 .a = op->input,
1195 .a_stride = op->input_pixel_stride,
1196 .a_packed = op->prepacked_a,
1197 .a_packed_stride = k_stride * mr,
1198 .log2_mr = log2_mr,
1199 .log2_row_block_size = log2_row_block_size,
1200 .kernel_indices_dtype = op->sparse_matrix.indices_dtype,
1201 .kernel_values = op->sparse_matrix.values,
1202 .bias = (const float*)op->bias,
1203 .c = (float*)op->output,
1204 .c_stride = op->output_pixel_stride,
1205 .quantization_params = op->dynamic_conv_quantization_params,
1206 .prepack_ukernel = prepack_kernel,
1207 // kernel_col_indices, kernel_row_values, and ukernel assigned
1208 // below
1209 };
1210
1211 switch (q8gemm_prepack_sparse_dq_context.kernel_indices_dtype) {
1212 case pytorch_qnnp_sparse_matrix_indices_dtype_uint32_t:
1213 q8gemm_prepack_sparse_dq_context.kernel_col_indices_w32 =
1214 op->sparse_matrix.col_indices_w32;
1215 q8gemm_prepack_sparse_dq_context.kernel_row_values_w32 =
1216 op->sparse_matrix.row_values_w32;
1217 q8gemm_prepack_sparse_dq_context.ukernel_w32 =
1218 pytorch_q8gemm_sparse_params->packedA_w32_gemm_dq;
1219 break;
1220 case pytorch_qnnp_sparse_matrix_indices_dtype_uint16_t:
1221 q8gemm_prepack_sparse_dq_context.kernel_col_indices_w16 =
1222 op->sparse_matrix.col_indices_w16;
1223 q8gemm_prepack_sparse_dq_context.kernel_row_values_w16 =
1224 op->sparse_matrix.row_values_w16;
1225 q8gemm_prepack_sparse_dq_context.ukernel_w16 =
1226 pytorch_q8gemm_sparse_params->packedA_w16_gemm_dq;
1227 break;
1228 case pytorch_qnnp_sparse_matrix_indices_dtype_uint8_t:
1229 q8gemm_prepack_sparse_dq_context.kernel_col_indices_w8 =
1230 op->sparse_matrix.col_indices_w8;
1231 q8gemm_prepack_sparse_dq_context.kernel_row_values_w8 =
1232 op->sparse_matrix.row_values_w8;
1233 q8gemm_prepack_sparse_dq_context.ukernel_w8 =
1234 pytorch_q8gemm_sparse_params->packedA_w8_gemm_dq;
1235 break;
1236 case pytorch_qnnp_sparse_matrix_indices_dtype_invalid:
1237 // Catch invalid index type and return early.
1238 // This ensures all subsequent calls will have a valid index type.
1239 pytorch_qnnp_log_error(
1240 "Invalid indices dtype specified for "
1241 "operator-run pytorch_qnnp_ukernel_type_gemm_prepackA_sparse_dq");
1242 return pytorch_qnnp_status_invalid_parameter;
1243 }
1244
1245 // This batch size is not the actual batch size of the op;
1246 // it is modified in fully-connected-sparse.c.
1247 if (groups != 1 || batch_size != 1) {
1248 pytorch_qnnp_log_error("pytorch_qnnp_ukernel_type_gemm_prepackA_sparse_dq "
1249 "works with group size = 1, batch_size = 1.\n");
1250 return pytorch_qnnp_status_invalid_parameter;
1251 }
1252
1253 pthreadpool_compute_4d_tiled(
1254 threadpool,
1255 (pthreadpool_function_4d_tiled_t)compute_q8gemm_prepack_a_sparse,
1256 &q8gemm_prepack_sparse_dq_context,
1257 1,
1258 1,
1259 output_size,
1260 1,
1261 1,
1262 1,
1263 mr,
1264 1);
1265
1266 pthreadpool_compute_4d_tiled(
1267 threadpool,
1268 (pthreadpool_function_4d_tiled_t)compute_q8gemm_prepacked_sparse_dq,
1269 &q8gemm_prepack_sparse_dq_context,
1270 groups,
1271 batch_size * output_size,
1272 output_size,
1273 group_output_channels,
1274 1,
1275 output_size,
1276 mr,
1277 nr);
1278 break;
1279 }
1280 case pytorch_qnnp_ukernel_type_conv: {
1281 const size_t batch_size = op->batch_size;
1282 const size_t groups = op->groups;
1283 const size_t group_input_channels = op->group_input_channels;
1284 const size_t group_output_channels = op->group_output_channels;
1285 const uint32_t mr = pytorch_qnnp_params.q8conv.mr;
1286 const uint32_t nr = pytorch_qnnp_params.q8conv.nr;
1287 const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
1288 const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
1289 const size_t n_stride = (group_output_channels + (nr - 1)) & -nr;
1290 const size_t output_depth = op->output_depth;
1291 const size_t output_size = (output_depth != 0 ? output_depth : 1) *
1292 op->output_height * op->output_width;
1293 const size_t kernel_depth = op->kernel_depth;
1294 const size_t kernel_size = (kernel_depth != 0 ? kernel_depth : 1) *
1295 op->kernel_height * op->kernel_width;
1296 const size_t m_stride = round_up(output_size, mr);
1297
1298 struct q8conv_context q8conv_context = {
1299 .bs = batch_size,
1300 .ks = kernel_size,
1301 .kc = group_input_channels,
1302 .kc_stride = k_stride * kernel_size,
1303 .m = output_size,
1304 .m_stride = m_stride,
1305 .n = group_output_channels,
1306 .n_stride = n_stride,
1307 .indirect_a = (const uint8_t**)op->indirection_buffer,
1308 .packed_w = op->packed_weights,
1309 .c = op->output,
1310 .c_stride = op->output_pixel_stride,
1311 .quantization_params = op->conv_quantization_params,
1312 .ukernel = pytorch_qnnp_params.q8conv.conv,
1313 };
1314
1315 pthreadpool_compute_4d_tiled(
1316 threadpool,
1317 (pthreadpool_function_4d_tiled_t)compute_q8conv,
1318 &q8conv_context,
1319 groups,
1320 batch_size,
1321 output_size,
1322 group_output_channels,
1323 1,
1324 1,
1325 mr,
1326 nr);
1327 break;
1328 }
1329 case pytorch_qnnp_ukernel_type_average_pooling: {
1330 const uint32_t kr = pytorch_qnnp_params.q8avgpool.kr;
1331 const uint32_t mr = pytorch_qnnp_params.q8avgpool.mr;
1332 const uint32_t qr = pytorch_qnnp_params.q8avgpool.qr;
1333 const size_t channels = op->channels;
1334 const size_t output_width = op->output_width;
1335 const size_t output_height = op->output_height;
1336 const size_t pooling_height = op->kernel_height;
1337 const size_t pooling_width = op->kernel_width;
1338 const size_t pooling_size = pooling_height * pooling_width;
1339
1340 const size_t indirect_input_height_stride =
1341 op->step_height * sizeof(void*);
1342 const size_t output_height_stride =
1343 output_width * op->output_pixel_stride;
1344
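// The adjustment below presumably reflects the multi-pass kernel consuming
// mr indirection pointers on its first pass and qr on each subsequent pass,
// so that input_increment lands on the next output pixel's window.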
1345 size_t multipass_adjustment = 0;
1346 if (channels >= kr && pooling_size > mr) {
1347 multipass_adjustment = round_up(pooling_size - mr, qr) + mr - qr;
1348 }
1349 struct average_pooling_context context = {
1350 .indirect_input = op->indirection_buffer,
1351 .indirect_input_batch_stride =
1352 output_height * indirect_input_height_stride,
1353 .indirect_input_height_stride = indirect_input_height_stride,
1354 .output = op->output,
1355 .output_batch_stride = output_height * output_height_stride,
1356 .output_height_stride = output_height_stride,
1357 .output_width = output_width,
1358 .pooling_size = pooling_size,
1359 .channels = channels,
1360 .packed_channels = (channels + (kr - 1)) & -kr,
1361 .zero = op->zero_pointer,
1362 .input_increment =
1363 (pooling_height * op->step_width - multipass_adjustment) *
1364 sizeof(void*),
1365 .output_increment =
1366 (op->output_pixel_stride - channels) * sizeof(uint8_t),
1367 .quantization_params = op->avgpool_quantization_params,
1368 };
1369
1370 pthreadpool_function_2d_t compute_function = NULL;
1371 if (channels < kr) {
1372 compute_function =
1373 (pthreadpool_function_2d_t)compute_average_pooling_unipass;
1374 context.unipass_ukernel = pytorch_qnnp_params.q8avgpool.ltkr;
1375 } else {
1376 if (pooling_size <= mr) {
1377 compute_function =
1378 (pthreadpool_function_2d_t)compute_average_pooling_unipass;
1379 context.unipass_ukernel = pytorch_qnnp_params.q8avgpool.gekr_lemr;
1380 } else {
1381 compute_function =
1382 (pthreadpool_function_2d_t)compute_average_pooling_multipass;
1383 context.multipass_ukernel = pytorch_qnnp_params.q8avgpool.gekr_gtmr;
1384 }
1385 }
1386
1387 pthreadpool_compute_2d(
1388 threadpool,
1389 compute_function,
1390 &context,
1391 op->batch_size,
1392 output_height);
1393 break;
1394 }
1395 case pytorch_qnnp_ukernel_type_max_pooling: {
1396 const uint32_t kr = pytorch_qnnp_params.u8maxpool.kr;
1397 const uint32_t mr = pytorch_qnnp_params.u8maxpool.mr;
1398 const uint32_t qr = pytorch_qnnp_params.u8maxpool.qr;
1399 const size_t channels = op->channels;
1400 const size_t output_width = op->output_width;
1401 const size_t output_height = op->output_height;
1402 const size_t pooling_height = op->kernel_height;
1403 const size_t pooling_width = op->kernel_width;
1404 const size_t pooling_size = pooling_height * pooling_width;
1405
1406 const size_t indirect_input_height_stride =
1407 op->step_height * sizeof(void*);
1408 const size_t output_height_stride =
1409 output_width * op->output_pixel_stride;
1410
1411 size_t multipass_adjustment = pooling_size;
1412 if (channels >= kr) {
1413 multipass_adjustment = round_up(doz(pooling_size, mr), qr) + mr;
1414 }
1415 struct max_pooling_context context = {
1416 .indirect_input = op->indirection_buffer,
1417 .indirect_input_batch_stride =
1418 output_height * indirect_input_height_stride,
1419 .indirect_input_height_stride = indirect_input_height_stride,
1420 .output = op->output,
1421 .output_batch_stride = output_height * output_height_stride,
1422 .output_height_stride = output_height_stride,
1423 .output_width = output_width,
1424 .pooling_size = pooling_size,
1425 .channels = channels,
1426 .input_increment =
1427 (pooling_height * op->step_width - multipass_adjustment) *
1428 sizeof(void*),
1429 .output_increment =
1430 (op->output_pixel_stride - channels) * sizeof(uint8_t),
1431 .params = op->u8_clamping_params,
1432 .ukernel = channels < kr ? pytorch_qnnp_params.u8maxpool.ltkr
1433 : pytorch_qnnp_params.u8maxpool.gekr,
1434 };
1435
1436 pthreadpool_compute_2d(
1437 threadpool,
1438 (pthreadpool_function_2d_t)compute_max_pooling,
1439 &context,
1440 op->batch_size,
1441 output_height);
1442 break;
1443 }
1444 case pytorch_qnnp_ukernel_type_add: {
1445 const size_t batch_size = op->batch_size;
1446 const size_t channels = op->channels;
1447 const size_t a_stride = op->input_pixel_stride;
1448 const size_t b_stride = op->input2_pixel_stride;
1449 const size_t y_stride = op->output_pixel_stride;
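// If every pixel stride equals the channel count, or there is only a single
// row, inputs and output are dense and the whole batch can be added as one
// contiguous array; otherwise fall back to the per-row strided path.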
1450 if ((((a_stride ^ channels) | (b_stride ^ channels) |
1451 (y_stride ^ channels)) == 0) ||
1452 batch_size == 1) {
1453 const size_t block_size = 4096;
1454 struct q8add_contiguous_context add_context = {
1455 .a = op->input,
1456 .b = op->input2,
1457 .y = op->output,
1458 .quantization_params = op->add_quantization_params,
1459 .ukernel = pytorch_qnnp_params.q8vadd,
1460 };
1461 pthreadpool_compute_1d_tiled(
1462 threadpool,
1463 (pthreadpool_function_1d_tiled_t)compute_q8add_contiguous,
1464 &add_context,
1465 batch_size * channels * sizeof(uint8_t),
1466 block_size);
1467 } else {
1468 struct q8add_strided_context add_context = {
1469 .a = op->input,
1470 .a_stride = a_stride * sizeof(uint8_t),
1471 .b = op->input2,
1472 .b_stride = b_stride * sizeof(uint8_t),
1473 .y = op->output,
1474 .y_stride = y_stride * sizeof(uint8_t),
1475 .n = channels,
1476 .quantization_params = op->add_quantization_params,
1477 .ukernel = pytorch_qnnp_params.q8vadd,
1478 };
1479 pthreadpool_compute_1d_tiled(
1480 threadpool,
1481 (pthreadpool_function_1d_tiled_t)compute_q8add_strided,
1482 &add_context,
1483 batch_size,
1484 1);
1485 }
1486 break;
1487 }
1488 case pytorch_qnnp_ukernel_type_global_average_pooling: {
1489 const uint32_t nr = pytorch_qnnp_params.q8gavgpool.nr;
1490 const uint32_t mr = pytorch_qnnp_params.q8gavgpool.mr;
1491 const size_t input_pixel_stride =
1492 op->input_pixel_stride * sizeof(uint8_t);
1493 const size_t input_width = op->input_width;
1494 const size_t channels = op->channels;
1495 struct global_average_pooling_context context = {
1496 .input = op->input,
1497 .zero = op->zero_pointer,
1498 .input_pixel_stride = input_pixel_stride,
1499 .input_batch_stride = input_pixel_stride * input_width,
1500 .input_elements = input_width,
1501 .channels = channels,
1502 .packed_channels = (channels + (nr - 1)) & -nr,
1503 .output = op->output,
1504 .output_batch_stride = op->output_pixel_stride * sizeof(uint8_t),
1505 .quantization_params = op->avgpool_quantization_params,
1506 };
1507 pthreadpool_function_1d_t compute_function = NULL;
1508 if (channels < nr) {
1509 compute_function =
1510 (pthreadpool_function_1d_t)compute_global_average_pooling_unipass;
1511 context.unipass_ukernel = pytorch_qnnp_params.q8gavgpool.ltnr;
1512 } else {
1513 if (input_width <= mr) {
1514 compute_function =
1515 (pthreadpool_function_1d_t)compute_global_average_pooling_unipass;
1516 context.unipass_ukernel = pytorch_qnnp_params.q8gavgpool.genr_lemr;
1517 } else {
1518 compute_function = (pthreadpool_function_1d_t)
1519 compute_global_average_pooling_multipass;
1520 context.multipass_ukernel = pytorch_qnnp_params.q8gavgpool.genr_gtmr;
1521 }
1522 }
1523
1524 pthreadpool_compute_1d(
1525 threadpool, compute_function, &context, op->batch_size);
1526 break;
1527 }
1528 case pytorch_qnnp_ukernel_type_lut: {
1529 const size_t batch_size = op->batch_size;
1530 const size_t channels = op->channels;
1531 const size_t x_stride = op->input_pixel_stride;
1532 const size_t y_stride = op->output_pixel_stride;
1533 if ((((x_stride ^ channels) | (y_stride ^ channels)) == 0) ||
1534 batch_size == 1) {
1535 const size_t block_size = 1024;
1536 struct lut_contiguous_context context = {
1537 .x = op->input,
1538 .x_stride = x_stride * sizeof(uint8_t),
1539 .t = op->lookup_table,
1540 .y = op->output,
1541 .y_stride = y_stride * sizeof(uint8_t),
1542 .ukernel = pytorch_qnnp_params.x8lut,
1543 };
1544 pthreadpool_compute_1d_tiled(
1545 threadpool,
1546 (pthreadpool_function_1d_tiled_t)compute_lut_contiguous,
1547 &context,
1548 batch_size * channels * sizeof(uint8_t),
1549 block_size);
1550 } else {
1551 struct lut_strided_context context = {
1552 .n = channels,
1553 .x = op->input,
1554 .x_stride = x_stride * sizeof(uint8_t),
1555 .t = op->lookup_table,
1556 .y = op->output,
1557 .y_stride = y_stride * sizeof(uint8_t),
1558 .ukernel = pytorch_qnnp_params.x8lut,
1559 };
1560 pthreadpool_compute_1d(
1561 threadpool,
1562 (pthreadpool_function_1d_t)compute_lut_strided,
1563 &context,
1564 batch_size);
1565 }
1566 break;
1567 }
1568 case pytorch_qnnp_ukernel_type_clamp: {
1569 const size_t batch_size = op->batch_size;
1570 const size_t channels = op->channels;
1571 const size_t x_stride = op->input_pixel_stride;
1572 const size_t y_stride = op->output_pixel_stride;
1573 if ((((x_stride ^ channels) | (y_stride ^ channels)) == 0) ||
1574 batch_size == 1) {
1575 const size_t block_size = 4096;
1576 struct clamp_contiguous_context context = {
1577 .x = op->input,
1578 .x_stride = x_stride * sizeof(uint8_t),
1579 .y = op->output,
1580 .y_stride = y_stride * sizeof(uint8_t),
1581 .ukernel = pytorch_qnnp_params.u8clamp,
1582 .params = op->u8_clamping_params,
1583 };
1584 pthreadpool_compute_1d_tiled(
1585 threadpool,
1586 (pthreadpool_function_1d_tiled_t)compute_clamp_contiguous,
1587 &context,
1588 batch_size * channels * sizeof(uint8_t),
1589 block_size);
1590 } else {
1591 struct clamp_strided_context context = {
1592 .n = channels,
1593 .x = op->input,
1594 .x_stride = x_stride * sizeof(uint8_t),
1595 .y = op->output,
1596 .y_stride = y_stride * sizeof(uint8_t),
1597 .ukernel = pytorch_qnnp_params.u8clamp,
1598 .params = op->u8_clamping_params,
1599 };
1600 pthreadpool_compute_1d(
1601 threadpool,
1602 (pthreadpool_function_1d_t)compute_clamp_strided,
1603 &context,
1604 batch_size);
1605 }
1606 break;
1607 }
1608 case pytorch_qnnp_ukernel_type_softargmax: {
1609 struct u8softargmax_context context = {
1610 .n = op->channels,
1611 .x = op->input,
1612 .x_stride = op->input_pixel_stride * sizeof(uint8_t),
1613 .t = op->lookup_table,
1614 .y = op->output,
1615 .y_stride = op->output_pixel_stride * sizeof(uint8_t),
1616 .rmax_ukernel = pytorch_qnnp_params.u8rmax,
1617 .lut_norm_ukernel = pytorch_qnnp_params.u8lut32norm,
1618 };
1619 pthreadpool_compute_1d(
1620 threadpool,
1621 (pthreadpool_function_1d_t)compute_u8softargmax,
1622 &context,
1623 op->batch_size);
1624 break;
1625 }
1626 case pytorch_qnnp_ukernel_type_channel_shuffle: {
1627 const size_t groups = op->groups;
1628 struct channel_shuffle_context channel_shuffle_context = {
1629 .x = op->input,
1630 .x_stride = op->input_pixel_stride * sizeof(uint8_t),
1631 .y = op->output,
1632 .y_stride = op->output_pixel_stride * sizeof(uint8_t),
1633 .n = op->group_channels * sizeof(uint8_t),
1634 .m = groups,
1635 };
1636 pthreadpool_function_1d_t compute_function = NULL;
1637 switch (groups) {
1638 case 2:
1639 compute_function =
1640 (pthreadpool_function_1d_t)compute_channel_shuffle_fixed;
1641 channel_shuffle_context.fixed_ukernel = pytorch_qnnp_params.x8zip.x2;
1642 break;
1643 case 3:
1644 compute_function =
1645 (pthreadpool_function_1d_t)compute_channel_shuffle_fixed;
1646 channel_shuffle_context.fixed_ukernel = pytorch_qnnp_params.x8zip.x3;
1647 break;
1648 case 4:
1649 compute_function =
1650 (pthreadpool_function_1d_t)compute_channel_shuffle_fixed;
1651 channel_shuffle_context.fixed_ukernel = pytorch_qnnp_params.x8zip.x4;
1652 break;
1653 default:
1654 compute_function =
1655 (pthreadpool_function_1d_t)compute_channel_shuffle_variable;
1656 channel_shuffle_context.variable_ukernel =
1657 pytorch_qnnp_params.x8zip.xm;
1658 break;
1659 case 0:
1660 case 1:
1661 PYTORCH_QNNP_UNREACHABLE;
1662 }
1663 pthreadpool_compute_1d(
1664 threadpool,
1665 compute_function,
1666 &channel_shuffle_context,
1667 op->batch_size);
1668 break;
1669 }
1670 default:
1671 PYTORCH_QNNP_UNREACHABLE;
1672 }
1673 return pytorch_qnnp_status_success;
1674 }
1675