/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <assert.h>
#include <inttypes.h> /* PRIu32, used by the log messages below */
#include <math.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <fxdiv.h>

#include <pytorch_qnnpack.h>
#include <qnnpack/common.h>
#include <qnnpack/indirection.h>
#include <qnnpack/log.h>
#include <qnnpack/math.h>
#include <qnnpack/operator.h>
#include <qnnpack/pack.h>
#include <qnnpack/params.h>

static inline size_t compute_output_dimension(
    size_t padded_input_dimension,
    size_t kernel_dimension,
    size_t dilation_dimension,
    size_t subsampling_dimension) {
  const size_t effective_kernel_dimension =
      (kernel_dimension - 1) * dilation_dimension + 1;
  return (padded_input_dimension - effective_kernel_dimension) /
      subsampling_dimension +
      1;
}
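
/*
 * Worked example (illustrative, not from the original source): an input of
 * height 7 with padding 1 on each side gives a padded height of 9; a kernel
 * of size 3 with dilation 2 has an effective extent of (3 - 1) * 2 + 1 = 5,
 * so with subsampling (stride) 2 the output height is (9 - 5) / 2 + 1 = 3.
 */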

/**
 * Not exposed in header file
 */
static enum pytorch_qnnp_status pytorch_qnnp_create_convolution_ndhwc_q8(
    uint32_t input_padding_depth,
    uint32_t input_padding_height,
    uint32_t input_padding_width,
    uint32_t kernel_depth,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t subsampling_depth,
    uint32_t subsampling_height,
    uint32_t subsampling_width,
    uint32_t dilation_depth,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    uint8_t input_zero_point,
    const uint8_t* kernel_zero_points,
    const uint8_t* kernel,
    const int32_t* bias,
    uint8_t output_zero_point,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    const float* requantization_scales,
    bool per_channel,
    pytorch_qnnp_operator_t* convolution_out,
    bool is_2d /* true: 2d, false: 3d */) {
  pytorch_qnnp_operator_t convolution = NULL;
  enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized;

  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_create_convolution_ndhwc_q8 failed because QNNPACK is not properly initialized");
    goto error;
  }

  status = pytorch_qnnp_status_invalid_parameter;

  if (kernel_width == 0 || kernel_height == 0) {
    pytorch_qnnp_log_error(
        "failed to create convolution with %" PRIu32 "x%" PRIu32
        " kernel: kernel dimensions must be non-zero",
        kernel_width,
        kernel_height);
    goto error;
  }

  if (subsampling_width == 0 || subsampling_height == 0) {
    pytorch_qnnp_log_error(
        "failed to create convolution with %" PRIu32 "x%" PRIu32
        " subsampling: "
        "subsampling dimensions must be non-zero",
        subsampling_width,
        subsampling_height);
    goto error;
  }

  if (dilation_width == 0 || dilation_height == 0) {
    pytorch_qnnp_log_error(
        "failed to create convolution with %" PRIu32 "x%" PRIu32
        " dilation: "
        "dilation dimensions must be non-zero",
        dilation_width,
        dilation_height);
    goto error;
  }

  status = pytorch_qnnp_status_unsupported_parameter;

  if (subsampling_height > kernel_height) {
    pytorch_qnnp_log_info(
        "inefficiency in convolution with %" PRIu32 "x%" PRIu32
        " kernel and %" PRIu32 "x%" PRIu32
        " subsampling: "
        "height subsampling is greater than kernel height; subsampling should be performed before the convolution",
        kernel_width,
        kernel_height,
        subsampling_width,
        subsampling_height);
  }

  if (subsampling_width > kernel_width) {
    pytorch_qnnp_log_info(
        "inefficiency in convolution with %" PRIu32 "x%" PRIu32
        " kernel and %" PRIu32 "x%" PRIu32
        " subsampling: "
        "width subsampling is greater than kernel width; subsampling should be performed before the convolution",
        kernel_width,
        kernel_height,
        subsampling_width,
        subsampling_height);
  }

  if (input_padding_depth >= kernel_depth) {
    pytorch_qnnp_log_info(
        "inefficiency in convolution with %" PRIu32 "x%" PRIu32 "x%" PRIu32
        " kernel and %" PRIu32 "+%" PRIu32
        " depth padding: "
        "input depth padding is greater than or equal to kernel depth",
        kernel_depth,
        kernel_height,
        kernel_width,
        input_padding_depth,
        input_padding_depth);
  }

  if (input_padding_height >= kernel_height) {
    pytorch_qnnp_log_info(
        "inefficiency in convolution with %" PRIu32 "x%" PRIu32 "x%" PRIu32
        " kernel and %" PRIu32 "+%" PRIu32
        " height padding: "
        "input height padding is greater than or equal to kernel height",
        kernel_depth,
        kernel_height,
        kernel_width,
        input_padding_height,
        input_padding_height);
  }

  if (input_padding_width >= kernel_width) {
    pytorch_qnnp_log_info(
        "inefficiency in convolution with %" PRIu32 "x%" PRIu32 "x%" PRIu32
        " kernel and %" PRIu32 "+%" PRIu32
        " width padding: "
        "input width padding is greater than or equal to kernel width",
        kernel_depth,
        kernel_height,
        kernel_width,
        input_padding_width,
        input_padding_width);
  }

  for (size_t i = 0; i < groups * group_output_channels; ++i) {
    if (requantization_scales[i] <= 0.0f ||
        !isnormal(requantization_scales[i])) {
      pytorch_qnnp_log_error(
          "failed to create convolution with %.7g requantization scale: scale must be finite and positive",
          requantization_scales[i]);
      goto error;
    }
  }

  status = pytorch_qnnp_status_out_of_memory;

  convolution = calloc(1, sizeof(struct pytorch_qnnp_operator));
  if (convolution == NULL) {
    pytorch_qnnp_log_error(
        "failed to allocate %zu bytes for pytorch_qnnp_operator structure",
        sizeof(struct pytorch_qnnp_operator));
    goto error;
  }

  const size_t kernel_size = kernel_height * kernel_width * kernel_depth;

  enum pytorch_qnnp_ukernel_type ukernel_type = pytorch_qnnp_ukernel_type_none;
  const bool any_padding =
      (input_padding_depth | input_padding_height | input_padding_width) != 0;

  const bool has_depthwise_dimensions =
      (is_2d &&
       ((kernel_height == 3 && kernel_width == 3) ||
        (kernel_height == 5 && kernel_width == 5))) ||
      (!is_2d && kernel_height == 3 && kernel_width == 3 && kernel_depth == 3);
  const bool has_depthwise_grouping =
      group_input_channels == 1 && group_output_channels == 1 && groups > 1;
  if (has_depthwise_dimensions && has_depthwise_grouping) {
    ukernel_type = pytorch_qnnp_ukernel_type_dwconv;
  } else if (
      kernel_size == 1 && subsampling_height == 1 && subsampling_width == 1 &&
      !any_padding) {
    ukernel_type =
        group_input_channels >= pytorch_qnnp_params.q8conv_xzp.kthreshold
        ? pytorch_qnnp_ukernel_type_xzp_gemm
        : pytorch_qnnp_ukernel_type_gemm;
  } else {
    ukernel_type = pytorch_qnnp_ukernel_type_conv;
  }
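  /*
   * Illustrative examples of the selection above (not from the original
   * source): a 3x3 convolution with groups == 32 and one input and one
   * output channel per group dispatches to the depthwise micro-kernel; a
   * 1x1 convolution with unit stride and no padding maps to plain GEMM (or
   * the XZP GEMM variant once group_input_channels reaches the configured
   * kthreshold); everything else takes the generic conv path through the
   * indirection buffer.
   */
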
  size_t zero_size = 0, zero_offset = 0;

  switch (ukernel_type) {
    // This also covers the case of dwconv_per_channel
    // since the weight packing is shared between the two.
    case pytorch_qnnp_ukernel_type_dwconv: {
      const uint32_t cr = pytorch_qnnp_params.q8dw9.cr;
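      /*
       * Round the channel count up to a multiple of cr (a power of two).
       * For example, groups == 10 with cr == 8 yields c_stride == 16.
       */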
      const uint32_t c_stride = (groups + (cr - 1)) & -cr;
      convolution->group_stride = c_stride;
      const size_t packed_weights_size =
          (sizeof(uint8_t) * kernel_size + sizeof(int32_t)) * c_stride;
      convolution->packed_weights = malloc(packed_weights_size);
      if (convolution->packed_weights == NULL) {
        pytorch_qnnp_log_error(
            "failed to allocate %zu bytes for packed weights",
            packed_weights_size);
        goto error;
      }

      switch (kernel_size) {
        case 9:
          pytorch_pack_q8dw_w(
              kernel_height,
              kernel_width,
              groups,
              cr,
#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION
              input_zero_point,
              kernel_zero_points[0],
#endif
              kernel,
              bias,
              convolution->packed_weights);
          break;
        case 25:
          /*
           * The 5x5 kernel appears to be packed in three column bands
           * (columns 0-2, 2-4, and 4-5) via the 2D dilation packing
           * helper; the trailing boolean on the first call packs the bias
           * along with the first band. TODO: change this later.
           */
          pytorch_pack_q8dw_2d_w_dilation(
              kernel_height,
              kernel_width,
              groups,
              cr,
              0,
              kernel_height,
              0,
              2,
              kernel,
              bias,
              convolution->packed_weights,
              true);
          pytorch_pack_q8dw_2d_w_dilation(
              kernel_height,
              kernel_width,
              groups,
              cr,
              0,
              kernel_height,
              2,
              4,
              kernel,
              bias,
              (char*)convolution->packed_weights +
                  (10 + sizeof(int32_t) / sizeof(uint8_t)) * c_stride,
              false);
          pytorch_pack_q8dw_2d_w_dilation(
              kernel_height,
              kernel_width,
              groups,
              cr,
              0,
              kernel_height,
              4,
              5,
              kernel,
              bias,
              (char*)convolution->packed_weights +
                  (20 + sizeof(int32_t) / sizeof(uint8_t)) * c_stride,
              false);
          break;
        case 27:
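          /*
           * Analogous to the 5x5 case above (annotation, not in the
           * original source): the 3x3x3 kernel appears to be packed in
           * three width bands of one column each, with the bias packed
           * only by the first call.
           */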
          pytorch_pack_q8dw_3d_w_dilation(
              kernel_depth,
              kernel_height,
              kernel_width,
              groups,
              cr,
              0,
              kernel_depth,
              0,
              kernel_height,
              0,
              1,
              kernel,
              bias,
              convolution->packed_weights,
              true);
          pytorch_pack_q8dw_3d_w_dilation(
              kernel_depth,
              kernel_height,
              kernel_width,
              groups,
              cr,
              0,
              kernel_depth,
              0,
              kernel_height,
              1,
              2,
              kernel,
              bias,
              (char*)convolution->packed_weights +
                  (kernel_depth * kernel_height +
                   sizeof(int32_t) / sizeof(uint8_t)) *
                      c_stride,
              false);
          pytorch_pack_q8dw_3d_w_dilation(
              kernel_depth,
              kernel_height,
              kernel_width,
              groups,
              cr,
              0,
              kernel_depth,
              0,
              kernel_height,
              2,
              3,
              kernel,
              bias,
              (char*)convolution->packed_weights +
                  (2 * kernel_depth * kernel_height +
                   sizeof(int32_t) / sizeof(uint8_t)) *
                      c_stride,
              false);
          break;
        default:
          PYTORCH_QNNP_UNREACHABLE;
      }

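      /*
       * Annotation on the sizing below (my reading of the code, not
       * documented in the original): the zero buffer supplies
       * zero-point-valued padding rows. With fewer than 8 channels, 8
       * extra bytes are reserved and the published pointer is offset by 8,
       * presumably to keep the micro-kernels' 8-byte loads in bounds.
       */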
      if (groups >= 8) {
        zero_size = sizeof(uint8_t) * c_stride;
        zero_offset = 0;
      } else {
        zero_size = sizeof(uint8_t) * c_stride + 8;
        zero_offset = sizeof(uint8_t) * 8;
      }
      break;
    }
    case pytorch_qnnp_ukernel_type_xzp_gemm: {
      // TODO: XZP kernels won't be supporting per-channel quantization.
      // For now we don't use XZP kernels anywhere. Probably deprecate them
      // for now and resurrect later if needed.
      const uint32_t nr = pytorch_qnnp_params.q8conv_xzp.nr;
      const uint32_t kr = pytorch_qnnp_params.q8conv_xzp.kr;
      const uint32_t sr = pytorch_qnnp_params.q8conv_xzp.kc;
      const uint32_t n_stride = (group_output_channels + (nr - 1)) & -nr;
      const uint32_t k_stride = (group_input_channels + (kr - 1)) & -kr;

      const size_t packed_group_weights_size =
          (sizeof(uint8_t) * kernel_size * k_stride + sizeof(int32_t)) *
          n_stride;
      convolution->packed_weights = malloc(packed_group_weights_size * groups);
      if (convolution->packed_weights == NULL) {
        pytorch_qnnp_log_error(
            "failed to allocate %zu bytes for packed weights",
            packed_group_weights_size * groups);
        goto error;
      }
      /* The XZP ukernel needs the padding to be 0 */
      memset(
          convolution->packed_weights, 0, packed_group_weights_size * groups);

      for (uint32_t group = 0; group < groups; group++) {
        pytorch_pack_swizzle_q8gemm_b(
            group_output_channels,
            group_input_channels,
            nr,
            kr,
            sr,
#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION
            input_zero_point,
            kernel_zero_points[0],
#endif
            kernel + group * group_output_channels * group_input_channels,
            bias + group * group_output_channels,
            (void*)((uintptr_t)convolution->packed_weights + group * packed_group_weights_size));
      }
      break;
    }
    case pytorch_qnnp_ukernel_type_gemm:
    case pytorch_qnnp_ukernel_type_conv: {
      const uint32_t nr = pytorch_qnnp_params.q8conv.nr;
      const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
      const uint32_t n_stride = (group_output_channels + (nr - 1)) & -nr;
      const uint32_t k_stride = (group_input_channels + (kr - 1)) & -kr;

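      /*
       * Per-group packed size (annotation, following the expression
       * below): each of the n_stride output channels stores
       * kernel_size * k_stride weight bytes plus one int32 bias, with both
       * channel counts rounded up to the micro-kernel tile sizes nr and kr.
       */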
      const size_t packed_group_weights_size =
          (sizeof(uint8_t) * kernel_size * k_stride + sizeof(int32_t)) *
          n_stride;
      convolution->packed_weights = malloc(packed_group_weights_size * groups);
      if (convolution->packed_weights == NULL) {
        pytorch_qnnp_log_error(
            "failed to allocate %zu bytes for packed weights",
            packed_group_weights_size * groups);
        goto error;
      }
      memset(
          convolution->packed_weights,
          kernel_zero_points[0],
          packed_group_weights_size * groups);

      switch (ukernel_type) {
        case pytorch_qnnp_ukernel_type_gemm:
          for (uint32_t group = 0; group < groups; group++) {
            pytorch_pack_q8gemm_w(
                group_output_channels,
                group_input_channels,
                nr,
                nr,
                kr,
#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION
                input_zero_point,
                kernel_zero_points[0],
#endif
                kernel + group * group_output_channels * group_input_channels,
                bias + group * group_output_channels,
#if PYTORCH_QNNPACK_RUNTIME_QUANTIZATION
                kernel_zero_points + group * group_output_channels,
#endif
                (void*)((uintptr_t)convolution->packed_weights + group * packed_group_weights_size));
          }
          break;
        case pytorch_qnnp_ukernel_type_conv:
          for (uint32_t group = 0; group < groups; group++) {
            pytorch_pack_q8conv_w(
                group_output_channels,
                kernel_size,
                group_input_channels,
                nr,
                kr,
#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION
                input_zero_point,
                kernel_zero_points[0],
#endif
                kernel +
                    group * group_output_channels * kernel_size *
                        group_input_channels,
                bias + group * group_output_channels,
#if PYTORCH_QNNPACK_RUNTIME_QUANTIZATION
                kernel_zero_points + group * group_output_channels,
#endif
                (void*)((uintptr_t)convolution->packed_weights + group * packed_group_weights_size));
          }
          break;
        default:
          PYTORCH_QNNP_UNREACHABLE;
      }

      if (group_input_channels >= 8) {
        zero_size = sizeof(uint8_t) * k_stride;
        zero_offset = 0;
      } else {
        zero_size = sizeof(uint8_t) * k_stride + 8;
        zero_offset = 8;
      }
      break;
    }
    default:
      PYTORCH_QNNP_UNREACHABLE;
  }

  if (any_padding) {
    void* zero_buffer = malloc(zero_size);
    if (zero_buffer == NULL) {
      pytorch_qnnp_log_error(
          "failed to allocate %zu bytes for zero padding", zero_size);
      goto error;
    }
    memset(zero_buffer, input_zero_point, zero_size);
    convolution->zero_buffer = zero_buffer;
    convolution->zero_pointer = (void*)((uintptr_t)zero_buffer + zero_offset);
  }

  convolution->input_padding_depth = input_padding_depth;
  convolution->input_padding_height = input_padding_height;
  convolution->input_padding_width = input_padding_width;
  convolution->kernel_depth = kernel_depth;
  convolution->kernel_height = kernel_height;
  convolution->kernel_width = kernel_width;
  convolution->stride_depth = subsampling_depth;
  convolution->stride_height = subsampling_height;
  convolution->stride_width = subsampling_width;
  convolution->dilation_depth = dilation_depth;
  convolution->dilation_height = dilation_height;
  convolution->dilation_width = dilation_width;
  convolution->groups = groups;
  convolution->group_input_channels = group_input_channels;
  convolution->group_output_channels = group_output_channels;

  convolution->kernel_zero_point = kernel_zero_points[0];

  if (ukernel_type == pytorch_qnnp_ukernel_type_xzp_gemm) {
    convolution->requantization_params =
        pytorch_qnnp_compute_requantization_params(
            requantization_scales[0], output_zero_point, output_min, output_max);
  } else {
    convolution->conv_quantization_params =
        pytorch_qnnp_compute_conv_quantization_params(
            input_zero_point,
            kernel_zero_points,
            requantization_scales,
            output_zero_point,
            output_min,
            output_max);
  }

  convolution->ukernel_type = ukernel_type;
  convolution->format = pytorch_qnnp_format_quint8;

  convolution->per_channel = per_channel;

  *convolution_out = convolution;
  return pytorch_qnnp_status_success;

error:
  pytorch_qnnp_delete_operator(convolution);
  return status;
}

enum pytorch_qnnp_status pytorch_qnnp_create_convolution2d_nhwc_q8(
    uint32_t input_padding_height,
    uint32_t input_padding_width,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t subsampling_height,
    uint32_t subsampling_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    uint8_t input_zero_point,
    const uint8_t* kernel_zero_points,
    const uint8_t* kernel,
    const int32_t* bias,
    uint8_t output_zero_point,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    const float* requantization_scales,
    bool per_channel,
    pytorch_qnnp_operator_t* convolution_out) {
  return pytorch_qnnp_create_convolution_ndhwc_q8(
      0,
      input_padding_height,
      input_padding_width,
      1,
      kernel_height,
      kernel_width,
      1,
      subsampling_height,
      subsampling_width,
      1,
      dilation_height,
      dilation_width,
      groups,
      group_input_channels,
      group_output_channels,
      input_zero_point,
      kernel_zero_points,
      kernel,
      bias,
      output_zero_point,
      output_min,
      output_max,
      flags,
      requantization_scales,
      per_channel,
      convolution_out,
      true /* is_2d? */);
}
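
/*
 * Minimal usage sketch (illustrative, with made-up values): creating a
 * per-tensor quantized 3x3 stride-1 convolution with padding 1, assuming
 * pytorch_qnnp_initialize() has already succeeded and that kernel, bias,
 * and kernel_zero_points point to appropriately sized buffers:
 *
 *   pytorch_qnnp_operator_t op = NULL;
 *   const enum pytorch_qnnp_status status =
 *       pytorch_qnnp_create_convolution2d_nhwc_q8(
 *           1, 1,            // input padding height, width
 *           3, 3,            // kernel height, width
 *           1, 1,            // subsampling (stride) height, width
 *           1, 1,            // dilation height, width
 *           1,               // groups
 *           64, 64,          // input, output channels per group
 *           128,             // input zero point
 *           kernel_zero_points,
 *           kernel,
 *           bias,
 *           128, 0, 255,     // output zero point, min, max
 *           0,               // flags
 *           requantization_scales,
 *           false,           // per_channel
 *           &op);
 */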

enum pytorch_qnnp_status pytorch_qnnp_create_convolution3d_ndhwc_q8(
    uint32_t input_padding_depth,
    uint32_t input_padding_height,
    uint32_t input_padding_width,
    uint32_t kernel_depth,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t subsampling_depth,
    uint32_t subsampling_height,
    uint32_t subsampling_width,
    uint32_t dilation_depth,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    uint8_t input_zero_point,
    const uint8_t* kernel_zero_points,
    const uint8_t* kernel,
    const int32_t* bias,
    uint8_t output_zero_point,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    const float* requantization_scales,
    bool per_channel,
    pytorch_qnnp_operator_t* convolution_out) {
  return pytorch_qnnp_create_convolution_ndhwc_q8(
      input_padding_depth,
      input_padding_height,
      input_padding_width,
      kernel_depth,
      kernel_height,
      kernel_width,
      subsampling_depth,
      subsampling_height,
      subsampling_width,
      dilation_depth,
      dilation_height,
      dilation_width,
      groups,
      group_input_channels,
      group_output_channels,
      input_zero_point,
      kernel_zero_points,
      kernel,
      bias,
      output_zero_point,
      output_min,
      output_max,
      flags,
      requantization_scales,
      per_channel,
      convolution_out,
      false /* is_2d? */);
}

enum pytorch_qnnp_status pytorch_qnnp_setup_convolution2d_nhwc_q8(
    pytorch_qnnp_operator_t convolution,
    size_t batch_size,
    size_t input_height,
    size_t input_width,
    const uint8_t* input,
    size_t input_pixel_stride,
    uint8_t* output,
    size_t output_pixel_stride,
    pthreadpool_t threadpool) {
  return pytorch_qnnp_setup_convolution_ndhwc_q8(
      convolution,
      batch_size,
      1,
      input_height,
      input_width,
      input,
      input_pixel_stride,
      output,
      output_pixel_stride,
      threadpool);
}

enum pytorch_qnnp_status pytorch_qnnp_setup_convolution_ndhwc_q8(
    pytorch_qnnp_operator_t convolution,
    size_t batch_size,
    size_t input_depth,
    size_t input_height,
    size_t input_width,
    const uint8_t* input,
    size_t input_pixel_stride,
    uint8_t* output,
    size_t output_pixel_stride,
    pthreadpool_t threadpool) {
  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_setup_convolution_ndhwc_q8 failed because QNNPACK is not properly initialized");
    return pytorch_qnnp_status_uninitialized;
  }

  if (batch_size == 0) {
    convolution->batch_size = 0;
    return pytorch_qnnp_status_success;
  }

  if (input_width == 0 || input_height == 0 || input_depth == 0) {
    pytorch_qnnp_log_error(
        "failed to setup convolution with %zux%zux%zu input: input dimensions must be non-zero",
        input_width,
        input_height,
        input_depth);
    return pytorch_qnnp_status_invalid_parameter;
  }

  convolution->batch_size = batch_size;
  convolution->input_depth = input_depth;
  convolution->input_height = input_height;
  convolution->input_width = input_width;
  convolution->input = input;
  convolution->input_pixel_stride = input_pixel_stride;

  convolution->output_depth = compute_output_dimension(
      input_depth + convolution->input_padding_depth * 2,
      convolution->kernel_depth,
      convolution->dilation_depth,
      convolution->stride_depth);
  convolution->output_height = compute_output_dimension(
      input_height + convolution->input_padding_height * 2,
      convolution->kernel_height,
      convolution->dilation_height,
      convolution->stride_height);
  convolution->output_width = compute_output_dimension(
      input_width + convolution->input_padding_width * 2,
      convolution->kernel_width,
      convolution->dilation_width,
      convolution->stride_width);
  convolution->output = output;
  convolution->output_pixel_stride = output_pixel_stride;

  switch (convolution->ukernel_type) {
    case pytorch_qnnp_ukernel_type_gemm:
      /* Convolution maps directly to GEMM and doesn't use indirection buffer */
      return pytorch_qnnp_status_success;
    case pytorch_qnnp_ukernel_type_xzp_gemm: {
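      /*
       * Annotation (not in the original source): the XZP variant
       * pre-computes input row sums so the zero-point correction can be
       * applied around the GEMM; a_sum reserves one int32 per input pixel
       * per group per batch element.
       */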
      const size_t groups = convolution->groups;
      const size_t input_size = input_depth * input_height * input_width;
      void* a_sum = (void*)realloc(
          convolution->a_sum,
          sizeof(int32_t) * batch_size * groups * input_size);
      if (a_sum == NULL) {
        pytorch_qnnp_log_error(
            "failed to allocate %zu bytes for row sum data",
            sizeof(int32_t) * batch_size * groups * input_size);
        return pytorch_qnnp_status_out_of_memory;
      }
      convolution->a_sum = a_sum;
      return pytorch_qnnp_status_success;
    }
    case pytorch_qnnp_ukernel_type_conv: {
      const size_t groups = convolution->groups;
      const size_t kernel_depth = convolution->kernel_depth;
      const size_t kernel_height = convolution->kernel_height;
      const size_t kernel_width = convolution->kernel_width;
      const size_t kernel_size = kernel_depth * kernel_height * kernel_width;
      const size_t output_depth = convolution->output_depth;
      const size_t output_height = convolution->output_height;
      const size_t output_width = convolution->output_width;
      const size_t output_size = output_depth * output_height * output_width;
      const size_t output_tile_size = pytorch_qnnp_params.q8conv.mr;
      const size_t tiled_output_size = round_up(output_size, output_tile_size);
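      /*
       * Annotation (not in the original source): the indirection buffer
       * replaces an explicit im2col by storing one input pointer per
       * kernel element for every output position, so it holds
       * batch_size * groups * tiled_output_size * kernel_size pointers;
       * the output size is rounded up to the mr tile so the micro-kernel
       * always reads full tiles.
       */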
      const size_t indirection_buffer_size =
          sizeof(void*) * batch_size * groups * tiled_output_size * kernel_size;

      const void** indirection_buffer = (const void**)realloc(
          convolution->indirection_buffer, indirection_buffer_size);
      if (indirection_buffer == NULL) {
        pytorch_qnnp_log_error(
            "failed to allocate %zu bytes for indirection buffer",
            indirection_buffer_size);
        return pytorch_qnnp_status_out_of_memory;
      }
      convolution->indirection_buffer = indirection_buffer;
      pytorch_qnnp_indirection_init_conv3d(
          convolution, output_tile_size, tiled_output_size);
      return pytorch_qnnp_status_success;
    }
    case pytorch_qnnp_ukernel_type_dwconv: {
      pytorch_qnnp_indirection_set_step_dimensions(convolution);

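      /*
       * Annotation (not in the original source): for depthwise kernels the
       * per-image indirection extent is folded into the "step" dimensions
       * computed above, so the buffer needs
       * batch_size * output_depth * step_depth pointers.
       */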
      const size_t indirection_buffer_size = sizeof(void*) * batch_size *
          convolution->output_depth * convolution->step_depth;

      const void** indirection_buffer = (const void**)realloc(
          convolution->indirection_buffer, indirection_buffer_size);
      if (indirection_buffer == NULL) {
        pytorch_qnnp_log_error(
            "failed to allocate %zu bytes for indirection buffer",
            indirection_buffer_size);
        return pytorch_qnnp_status_out_of_memory;
      }
      convolution->indirection_buffer = indirection_buffer;

      pytorch_qnnp_indirection_init_dwconv(convolution, 0);
      return pytorch_qnnp_status_success;
    }
    default:
      PYTORCH_QNNP_UNREACHABLE;
  }
}