/* xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/include/pytorch_qnnpack.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45) */
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #pragma once
10 
11 #include <stdbool.h>
12 #include <stddef.h>
13 #include <stdint.h>
14 
15 #include <pthreadpool.h>
16 #include <qnnpack/log.h>
17 
18 #ifdef __cplusplus
19 extern "C" {
20 #endif
21 
/**
 * @brief Status code returned by the QNNPACK API functions in this header.
 *
 * Only pytorch_qnnp_status_success means the call succeeded; every other
 * value reports that it did not. The enumerators carry explicit values
 * because they are part of the C ABI, so do not renumber them.
 */
enum pytorch_qnnp_status {
  /** The call succeeded, and all output arguments now contain valid data. */
  pytorch_qnnp_status_success = 0,
  /**
   * NOTE(review): presumably returned when the library has not been set up
   * via pytorch_qnnp_initialize(). Confirm in the implementation.
   */
  pytorch_qnnp_status_uninitialized = 1,
  /**
   * An argument was rejected as invalid (per the name). The exact conditions
   * are defined by each function's implementation.
   */
  pytorch_qnnp_status_invalid_parameter = 2,
  /**
   * An argument is not supported by this implementation (per the name).
   * Confirm which parameters can trigger it.
   */
  pytorch_qnnp_status_unsupported_parameter = 3,
  /**
   * NOTE(review): presumably the running CPU lacks features QNNPACK needs.
   * Confirm.
   */
  pytorch_qnnp_status_unsupported_hardware = 4,
  /** NOTE(review): presumably a memory allocation failed. Confirm. */
  pytorch_qnnp_status_out_of_memory = 5,
};
34 
/**
 * @brief Integer type of a sparse weight matrix's index arrays (per the name).
 *
 * Passed as kernel_indices_dtype to
 * pytorch_qnnp_create_fully_connected_sparse_dq_nc_q8(). Each valid
 * enumerator's value equals the bit width of the type it names (8, 16, 32).
 * 0 marks an invalid or unset type.
 */
enum pytorch_qnnp_sparse_matrix_indices_dtype {
  pytorch_qnnp_sparse_matrix_indices_dtype_invalid = 0,
  pytorch_qnnp_sparse_matrix_indices_dtype_uint8_t = 8,
  pytorch_qnnp_sparse_matrix_indices_dtype_uint16_t = 16,
  pytorch_qnnp_sparse_matrix_indices_dtype_uint32_t = 32,
};
41 
/**
 * @brief Library-wide setup.
 *
 * NOTE(review): presumably must succeed before any operator is created.
 * Confirm against the implementation.
 *
 * @return pytorch_qnnp_status_success on success, another
 *         pytorch_qnnp_status value otherwise.
 */
enum pytorch_qnnp_status pytorch_qnnp_initialize(void);

/**
 * @brief Library-wide teardown, presumably the counterpart of
 *        pytorch_qnnp_initialize().
 *
 * NOTE(review): its effect on operators that are still alive is not visible
 * from this header. Confirm before calling it with live handles.
 *
 * @return A pytorch_qnnp_status code.
 */
enum pytorch_qnnp_status pytorch_qnnp_deinitialize(void);
45 
/*
 * Opaque operator object. Its layout is not defined in this header, so
 * callers only ever see it through the handle type below.
 */
struct pytorch_qnnp_operator;

/**
 * @brief Handle to a QNNPACK operator.
 *
 * Filled in through the pytorch_qnnp_operator_t* out-parameter of the
 * pytorch_qnnp_create_* functions, and passed by value to the
 * pytorch_qnnp_setup_* functions, pytorch_qnnp_run_operator() and
 * pytorch_qnnp_delete_operator().
 */
typedef struct pytorch_qnnp_operator* pytorch_qnnp_operator_t;
47 
48 enum pytorch_qnnp_status pytorch_qnnp_create_convolution2d_nhwc_q8(
49     uint32_t input_padding_height,
50     uint32_t input_padding_width,
51     uint32_t kernel_height,
52     uint32_t kernel_width,
53     uint32_t subsampling_height,
54     uint32_t subsampling_width,
55     uint32_t dilation_height,
56     uint32_t dilation_width,
57     uint32_t groups,
58     size_t group_input_channels,
59     size_t group_output_channels,
60     uint8_t input_zero_point,
61     const uint8_t* kernel_zero_points,
62     const uint8_t* kernel,
63     const int32_t* bias,
64     uint8_t output_zero_point,
65     uint8_t output_min,
66     uint8_t output_max,
67     uint32_t flags,
68     const float* requantization_scales,
69     bool per_channel,
70     pytorch_qnnp_operator_t* convolution);
71 
72 enum pytorch_qnnp_status pytorch_qnnp_create_convolution3d_ndhwc_q8(
73     uint32_t input_padding_depth,
74     uint32_t input_padding_height,
75     uint32_t input_padding_width,
76     uint32_t kernel_depth,
77     uint32_t kernel_height,
78     uint32_t kernel_width,
79     uint32_t subsampling_depth,
80     uint32_t subsampling_height,
81     uint32_t subsampling_width,
82     uint32_t dilation_depth,
83     uint32_t dilation_height,
84     uint32_t dilation_width,
85     uint32_t groups,
86     size_t group_input_channels,
87     size_t group_output_channels,
88     uint8_t input_zero_point,
89     const uint8_t* kernel_zero_points,
90     const uint8_t* kernel,
91     const int32_t* bias,
92     uint8_t output_zero_point,
93     uint8_t output_min,
94     uint8_t output_max,
95     uint32_t flags,
96     const float* requantization_scales,
97     bool per_channel,
98     pytorch_qnnp_operator_t* convolution);
99 
100 enum pytorch_qnnp_status pytorch_qnnp_setup_convolution2d_nhwc_q8(
101     pytorch_qnnp_operator_t convolution,
102     size_t batch_size,
103     size_t input_height,
104     size_t input_width,
105     const uint8_t* input,
106     size_t input_stride,
107     uint8_t* output,
108     size_t output_stride,
109     pthreadpool_t threadpool);
110 
111 enum pytorch_qnnp_status pytorch_qnnp_setup_convolution_ndhwc_q8(
112     pytorch_qnnp_operator_t convolution,
113     size_t batch_size,
114     size_t input_depth,
115     size_t input_height,
116     size_t input_width,
117     const uint8_t* input,
118     size_t input_stride,
119     uint8_t* output,
120     size_t output_stride,
121     pthreadpool_t threadpool);
122 
123 enum pytorch_qnnp_status pytorch_qnnp_create_deconvolution2d_nhwc_q8(
124     uint32_t input_padding_height,
125     uint32_t input_padding_width,
126     uint32_t adjustment_height,
127     uint32_t adjustment_width,
128     uint32_t kernel_height,
129     uint32_t kernel_width,
130     uint32_t stride_height,
131     uint32_t stride_width,
132     uint32_t dilation_height,
133     uint32_t dilation_width,
134     uint32_t groups,
135     size_t group_input_channels,
136     size_t group_output_channels,
137     uint8_t input_zero_point,
138     const uint8_t* kernel_zero_points,
139     const uint8_t* kernel,
140     const int32_t* bias,
141     uint8_t output_zero_point,
142     uint8_t output_min,
143     uint8_t output_max,
144     uint32_t flags,
145     const float* requantization_scales,
146     pytorch_qnnp_operator_t* deconvolution);
147 
148 enum pytorch_qnnp_status pytorch_qnnp_setup_deconvolution2d_nhwc_q8(
149     pytorch_qnnp_operator_t deconvolution,
150     size_t batch_size,
151     size_t input_height,
152     size_t input_width,
153     const uint8_t* input,
154     size_t input_stride,
155     uint8_t* output,
156     size_t output_stride,
157     pthreadpool_t threadpool);
158 
159 enum pytorch_qnnp_status pytorch_qnnp_create_fully_connected_nc_q8(
160     size_t input_channels,
161     size_t output_channels,
162     uint8_t input_zero_point,
163     const uint8_t* kernel_zero_points,
164     const uint8_t* kernel,
165     const int32_t* bias,
166     uint8_t output_zero_point,
167     uint8_t output_min,
168     uint8_t output_max,
169     uint32_t flags,
170     const float* requantization_scales,
171     pytorch_qnnp_operator_t* fully_connected);
172 
173 enum pytorch_qnnp_status pytorch_qnnp_create_fully_connected_sparse_dq_nc_q8(
174     size_t input_channels,
175     size_t output_channels,
176     uint8_t input_zero_point,
177     const uint8_t* kernel_zero_points,
178     const void* kernel_col_indices,
179     const void* kernel_row_values,
180     const uint8_t* kernel_values,
181     const uint32_t kernel_row_block_size,
182     const uint32_t kernel_col_block_size,
183     enum pytorch_qnnp_sparse_matrix_indices_dtype kernel_indices_dtype,
184     uint8_t output_zero_point,
185     uint8_t output_min,
186     uint8_t output_max,
187     uint32_t flags,
188     const float* requantization_scales,
189     bool use_prepack_kernel,
190     pytorch_qnnp_operator_t* fully_connected);
191 
192 enum pytorch_qnnp_status pytorch_qnnp_setup_fully_connected_nc_q8(
193     pytorch_qnnp_operator_t fully_connected,
194     size_t batch_size,
195     const uint8_t* input,
196     size_t input_stride,
197     uint8_t* output,
198     size_t output_stride);
199 
200 enum pytorch_qnnp_status pytorch_qnnp_setup_fully_connected_sparse_dq_nc_q8(
201     pytorch_qnnp_operator_t fully_connected,
202     size_t batch_size,
203     const uint8_t* input,
204     size_t input_stride,
205     const float* bias,
206     float* output,
207     size_t output_stride);
208 
209 enum pytorch_qnnp_status pytorch_qnnp_create_global_average_pooling_nwc_q8(
210     size_t channels,
211     uint8_t input_zero_point,
212     float input_scale,
213     uint8_t output_zero_point,
214     float output_scale,
215     uint8_t output_min,
216     uint8_t output_max,
217     uint32_t flags,
218     pytorch_qnnp_operator_t* global_average_pooling);
219 
220 enum pytorch_qnnp_status pytorch_qnnp_setup_global_average_pooling_nwc_q8(
221     pytorch_qnnp_operator_t global_average_pooling,
222     size_t batch_size,
223     size_t width,
224     const uint8_t* input,
225     size_t input_stride,
226     uint8_t* output,
227     size_t output_stride);
228 
229 enum pytorch_qnnp_status pytorch_qnnp_create_average_pooling2d_nhwc_q8(
230     uint32_t input_padding_height,
231     uint32_t input_padding_width,
232     uint32_t pooling_height,
233     uint32_t pooling_width,
234     uint32_t stride_height,
235     uint32_t stride_width,
236     size_t channels,
237     uint8_t input_zero_point,
238     float input_scale,
239     uint8_t output_zero_point,
240     float output_scale,
241     uint8_t output_min,
242     uint8_t output_max,
243     uint32_t flags,
244     pytorch_qnnp_operator_t* average_pooling);
245 
246 enum pytorch_qnnp_status pytorch_qnnp_setup_average_pooling2d_nhwc_q8(
247     pytorch_qnnp_operator_t average_pooling,
248     size_t batch_size,
249     size_t input_height,
250     size_t input_width,
251     const uint8_t* input,
252     size_t input_stride,
253     uint8_t* output,
254     size_t output_stride,
255     pthreadpool_t threadpool);
256 
257 enum pytorch_qnnp_status pytorch_qnnp_create_max_pooling2d_nhwc_u8(
258     uint32_t input_padding_height,
259     uint32_t input_padding_width,
260     uint32_t pooling_height,
261     uint32_t pooling_width,
262     uint32_t stride_height,
263     uint32_t stride_width,
264     uint32_t dilation_height,
265     uint32_t dilation_width,
266     size_t channels,
267     uint8_t output_min,
268     uint8_t output_max,
269     uint32_t flags,
270     pytorch_qnnp_operator_t* max_pooling);
271 
272 enum pytorch_qnnp_status pytorch_qnnp_setup_max_pooling2d_nhwc_u8(
273     pytorch_qnnp_operator_t max_pooling,
274     size_t batch_size,
275     size_t input_height,
276     size_t input_width,
277     const uint8_t* input,
278     size_t input_stride,
279     uint8_t* output,
280     size_t output_stride,
281     pthreadpool_t threadpool);
282 
283 enum pytorch_qnnp_status pytorch_qnnp_create_channel_shuffle_nc_x8(
284     size_t groups,
285     size_t group_channels,
286     uint32_t flags,
287     pytorch_qnnp_operator_t* channel_shuffle);
288 
289 enum pytorch_qnnp_status pytorch_qnnp_setup_channel_shuffle_nc_x8(
290     pytorch_qnnp_operator_t channel_shuffle,
291     size_t batch_size,
292     const uint8_t* input,
293     size_t input_stride,
294     uint8_t* output,
295     size_t output_stride);
296 
297 enum pytorch_qnnp_status pytorch_qnnp_create_add_nc_q8(
298     size_t channels,
299     uint8_t a_zero_point,
300     float a_scale,
301     uint8_t b_zero_point,
302     float b_scale,
303     uint8_t sum_zero_point,
304     float sum_scale,
305     uint8_t sum_min,
306     uint8_t sum_max,
307     uint32_t flags,
308     pytorch_qnnp_operator_t* add);
309 
310 enum pytorch_qnnp_status pytorch_qnnp_setup_add_nc_q8(
311     pytorch_qnnp_operator_t add,
312     size_t batch_size,
313     const uint8_t* a,
314     size_t a_stride,
315     const uint8_t* b,
316     size_t b_stride,
317     uint8_t* sum,
318     size_t sum_stride);
319 
320 enum pytorch_qnnp_status pytorch_qnnp_create_clamp_nc_u8(
321     size_t channels,
322     uint8_t output_min,
323     uint8_t output_max,
324     uint32_t flags,
325     pytorch_qnnp_operator_t* clamp);
326 
327 enum pytorch_qnnp_status pytorch_qnnp_setup_clamp_nc_u8(
328     pytorch_qnnp_operator_t clamp,
329     size_t batch_size,
330     const uint8_t* input,
331     size_t input_stride,
332     uint8_t* output,
333     size_t output_stride);
334 
335 enum pytorch_qnnp_status pytorch_qnnp_create_sigmoid_nc_q8(
336     size_t channels,
337     uint8_t input_zero_point,
338     float input_scale,
339     uint8_t output_zero_point,
340     float output_scale,
341     uint8_t output_min,
342     uint8_t output_max,
343     uint32_t flags,
344     pytorch_qnnp_operator_t* sigmoid);
345 
346 enum pytorch_qnnp_status pytorch_qnnp_setup_sigmoid_nc_q8(
347     pytorch_qnnp_operator_t sigmoid,
348     size_t batch_size,
349     const uint8_t* input,
350     size_t input_stride,
351     uint8_t* output,
352     size_t output_stride);
353 
354 enum pytorch_qnnp_status pytorch_qnnp_create_leaky_relu_nc_q8(
355     size_t channels,
356     float negative_slope,
357     uint8_t input_zero_point,
358     float input_scale,
359     uint8_t output_zero_point,
360     float output_scale,
361     uint8_t output_min,
362     uint8_t output_max,
363     uint32_t flags,
364     pytorch_qnnp_operator_t* leaky_relu);
365 
366 enum pytorch_qnnp_status pytorch_qnnp_setup_leaky_relu_nc_q8(
367     pytorch_qnnp_operator_t leaky_relu,
368     size_t batch_size,
369     const uint8_t* input,
370     size_t input_stride,
371     uint8_t* output,
372     size_t output_stride);
373 
374 enum pytorch_qnnp_status pytorch_qnnp_create_softargmax_nc_q8(
375     size_t channels,
376     float input_scale,
377     uint8_t output_zero_point,
378     float output_scale,
379     uint32_t flags,
380     pytorch_qnnp_operator_t* softargmax);
381 
382 enum pytorch_qnnp_status pytorch_qnnp_setup_softargmax_nc_q8(
383     pytorch_qnnp_operator_t softargmax,
384     size_t batch_size,
385     const uint8_t* input,
386     size_t input_stride,
387     uint8_t* output,
388     size_t output_stride);
389 
390 enum pytorch_qnnp_status pytorch_qnnp_create_tanh_nc_q8(
391     size_t channels,
392     uint8_t input_zero_point,
393     float input_scale,
394     uint8_t output_zero_point,
395     float output_scale,
396     uint8_t output_min,
397     uint8_t output_max,
398     uint32_t flags,
399     pytorch_qnnp_operator_t* tanh);
400 
401 enum pytorch_qnnp_status pytorch_qnnp_setup_tanh_nc_q8(
402     pytorch_qnnp_operator_t tanh,
403     size_t batch_size,
404     const uint8_t* input,
405     size_t input_stride,
406     uint8_t* output,
407     size_t output_stride);
408 
409 enum pytorch_qnnp_status pytorch_qnnp_create_hardsigmoid_nc_q8(
410     size_t channels,
411     uint8_t input_zero_point,
412     float input_scale,
413     uint8_t output_zero_point,
414     float output_scale,
415     uint8_t output_min,
416     uint8_t output_max,
417     uint32_t flags,
418     pytorch_qnnp_operator_t* hardsigmoid);
419 
420 enum pytorch_qnnp_status pytorch_qnnp_setup_hardsigmoid_nc_q8(
421     pytorch_qnnp_operator_t hardsigmoid,
422     size_t batch_size,
423     const uint8_t* input,
424     size_t input_stride,
425     uint8_t* output,
426     size_t output_stride);
427 
428 enum pytorch_qnnp_status pytorch_qnnp_create_hardswish_nc_q8(
429     size_t channels,
430     uint8_t input_zero_point,
431     float input_scale,
432     uint8_t output_zero_point,
433     float output_scale,
434     uint8_t output_min,
435     uint8_t output_max,
436     uint32_t flags,
437     pytorch_qnnp_operator_t* hardswish);
438 
439 enum pytorch_qnnp_status pytorch_qnnp_setup_hardswish_nc_q8(
440     pytorch_qnnp_operator_t hardswish,
441     size_t batch_size,
442     const uint8_t* input,
443     size_t input_stride,
444     uint8_t* output,
445     size_t output_stride);
446 
447 enum pytorch_qnnp_status pytorch_qnnp_run_operator(
448     pytorch_qnnp_operator_t op,
449     pthreadpool_t threadpool);
450 
451 enum pytorch_qnnp_status pytorch_qnnp_delete_operator(
452     pytorch_qnnp_operator_t op);
453 
454 #ifdef __cplusplus
455 } /* extern "C" */
456 #endif
457