// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/operator.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
#pragma once

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

#include <qnnpack/requantization.h>
15 
/*
 * Data-format code for an operator.
 *
 * Each byte holds log2(bytes per element) for one tensor role, matching the
 * pytorch_qnnp_operator_get_log2_*_element_size() accessors below:
 *   bits [ 0, 8) - output element size
 *   bits [ 8,16) - input element size
 *   bits [16,24) - kernel element size
 *   bits [24,32) - bias element size
 * E.g. quint8 = 0x02000000: 1-byte output/input/kernel, 4-byte bias.
 */
enum pytorch_qnnp_format {
  pytorch_qnnp_format_quint8 = 0x02000000,
  pytorch_qnnp_format_float32 = 0x02020202,
  pytorch_qnnp_format_float16 = 0x01010101,
};
21 
/*
 * Identifies which micro-kernel family an operator dispatches to at run time.
 * Stored in pytorch_qnnp_operator::ukernel_type; values beyond _none follow
 * default sequential enumeration (add = 1, ..., xzp_gemm = 14).
 */
enum pytorch_qnnp_ukernel_type {
  pytorch_qnnp_ukernel_type_none = 0,
  pytorch_qnnp_ukernel_type_add,
  pytorch_qnnp_ukernel_type_average_pooling,
  pytorch_qnnp_ukernel_type_channel_shuffle,
  pytorch_qnnp_ukernel_type_clamp,
  pytorch_qnnp_ukernel_type_conv,
  pytorch_qnnp_ukernel_type_dwconv,
  pytorch_qnnp_ukernel_type_gemm,
  pytorch_qnnp_ukernel_type_gemm_sparse_dq,
  pytorch_qnnp_ukernel_type_gemm_prepackA_sparse_dq,
  pytorch_qnnp_ukernel_type_global_average_pooling,
  pytorch_qnnp_ukernel_type_lut,
  pytorch_qnnp_ukernel_type_max_pooling,
  pytorch_qnnp_ukernel_type_softargmax,
  pytorch_qnnp_ukernel_type_xzp_gemm,
};
39 
40 typedef struct {
41   union {
42     const uint32_t* col_indices_w32;
43     const uint16_t* col_indices_w16;
44     const uint8_t* col_indices_w8;
45   };
46   union {
47     const uint32_t* row_values_w32;
48     const uint16_t* row_values_w16;
49     const uint8_t* row_values_w8;
50   };
51   const uint8_t* values;
52   uint32_t row_block_size;
53   uint32_t col_block_size;
54   enum pytorch_qnnp_sparse_matrix_indices_dtype indices_dtype;
55 } sparse_matrix_t;
56 
57 struct pytorch_qnnp_operator {
58   size_t batch_size;
59   uint32_t input_padding_depth;
60   uint32_t input_padding_height;
61   uint32_t input_padding_width;
62   uint32_t adjustment_height;
63   uint32_t adjustment_width;
64   uint32_t kernel_depth;
65   uint32_t kernel_height;
66   uint32_t kernel_width;
67   uint32_t stride_depth;
68   uint32_t stride_height;
69   uint32_t stride_width;
70   uint32_t dilation_depth;
71   uint32_t dilation_height;
72   uint32_t dilation_width;
73   uint32_t groups;
74   size_t group_stride;
75   size_t group_channels;
76   size_t group_input_channels;
77   size_t group_output_channels;
78   size_t channels;
79 
80   size_t input_depth;
81   size_t input_height;
82   size_t input_width;
83   size_t input_pixel_stride;
84   const void* input;
85   const void** indirection_buffer;
86   void* a_sum;
87 
88   size_t step_depth;
89   size_t step_height;
90   size_t step_width;
91 
92   size_t input2_pixel_stride;
93   const void* input2;
94 
95   size_t output_depth;
96   size_t output_height;
97   size_t output_width;
98   size_t output_pixel_stride;
99   void* output;
100 
101   void* packed_weights;
102   float input_scale;
103   float output_scale;
104   uint8_t input_zero_point;
105   uint8_t kernel_zero_point;
106   uint8_t output_zero_point;
107   uint8_t output_min;
108   uint8_t output_max;
109 
110   size_t valid_batch_size;
111   size_t last_input_height;
112   size_t last_input_width;
113   const void* last_input;
114 
115   void* zero_buffer;
116   void* zero_pointer;
117   void* lookup_table;
118 
119   union {
120     union pytorch_qnnp_q31_requantization_params requantization_params;
121     union pytorch_qnnp_conv_quantization_params conv_quantization_params;
122     union pytorch_qnnp_add_quantization_params add_quantization_params;
123     union pytorch_qnnp_avgpool_quantization_params avgpool_quantization_params;
124     union pytorch_qnnp_u8_clamping_params u8_clamping_params;
125   };
126   enum pytorch_qnnp_ukernel_type ukernel_type;
127   enum pytorch_qnnp_format format;
128 
129   bool per_channel;
130   bool transpose;
131 
132   // Sparsity support
133   sparse_matrix_t sparse_matrix;
134   const void* bias;
135   struct pytorch_qnnp_conv_dynamic_quantization_params dynamic_conv_quantization_params;
136   uint8_t* prepacked_a;
137 };
138 
pytorch_qnnp_operator_get_log2_output_element_size(const struct pytorch_qnnp_operator * convolution)139 static inline uint32_t pytorch_qnnp_operator_get_log2_output_element_size(
140     const struct pytorch_qnnp_operator* convolution) {
141   return (uint32_t)(convolution->format & UINT32_C(0xFF));
142 }
143 
pytorch_qnnp_operator_get_log2_input_element_size(const struct pytorch_qnnp_operator * convolution)144 static inline uint32_t pytorch_qnnp_operator_get_log2_input_element_size(
145     const struct pytorch_qnnp_operator* convolution) {
146   return (uint32_t)((convolution->format >> 8) & UINT32_C(0xFF));
147 }
148 
pytorch_qnnp_operator_get_log2_kernel_element_size(const struct pytorch_qnnp_operator * convolution)149 static inline uint32_t pytorch_qnnp_operator_get_log2_kernel_element_size(
150     const struct pytorch_qnnp_operator* convolution) {
151   return (uint32_t)((convolution->format >> 16) & UINT32_C(0xFF));
152 }
153 
pytorch_qnnp_operator_get_log2_bias_element_size(const struct pytorch_qnnp_operator * convolution)154 static inline uint32_t pytorch_qnnp_operator_get_log2_bias_element_size(
155     const struct pytorch_qnnp_operator* convolution) {
156   return (uint32_t)((convolution->format >> 24) & UINT32_C(0xFF));
157 }
158