#pragma once

#ifdef USE_XNNPACK
#include <cstdint>

#include <ATen/core/Tensor.h>
#include <ATen/native/xnnpack/Common.h>

using xnnpack_operator = at::native::xnnpack::Operator;

namespace at {
namespace native {
namespace xnnp_utils {

/*
 * Return the shape in the same order as the memory format,
 * e.g. channels_last will return NHWC instead of NCHW.
 */
std::vector<size_t> get_mem_format_aware_shape(const at::Tensor& in);
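
/*
 * Illustrative sketch (an assumed example based on the comment above, not
 * part of this header's API): for a 4-D tensor stored as channels_last, the
 * returned vector follows the memory layout rather than the logical NCHW
 * sizes.
 *
 *   at::Tensor t =
 *       at::rand({2, 3, 4, 5}).contiguous(c10::MemoryFormat::ChannelsLast);
 *   // t.sizes() stays {2, 3, 4, 5} (N, C, H, W), but the returned shape is
 *   // expected to be {2, 4, 5, 3} (N, H, W, C).
 *   std::vector<size_t> nhwc_shape = get_mem_format_aware_shape(t);
 */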

/*
 * Input is always int8_t; output can be either int8_t or uint8_t.
 * input  + offset = output
 * int8_t + 128    = uint8_t
 * int8_t + 0      = int8_t
 */
template <typename PT>
void q8_copy_int8_weight_and_add_offset(const at::Tensor& in, at::Tensor& out);
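
/*
 * Illustrative sketch (assumed usage; the template argument and tensor names
 * below are hypothetical placeholders): copying int8 weights into a uint8
 * destination applies a +128 offset element-wise, while an int8 destination
 * gets a plain copy.
 *
 *   at::Tensor w_int8 = ...;                        // int8 weight tensor
 *   at::Tensor w_uint8 = at::empty_like(w_int8, at::kByte);
 *   q8_copy_int8_weight_and_add_offset<c10::quint8>(w_int8, w_uint8);
 *   // element-wise: w_uint8[i] == static_cast<uint8_t>(w_int8[i] + 128)
 */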

template <int kSpatialDim>
Tensor convert_conv_weights_to_channel_last_tensor(
    const at::Tensor& src,
    int groups,
    bool transpose);

/*
 * Series of create wrapper functions to call xnn_create_[de]conv* functions.
 */
C10_ALWAYS_INLINE
enum xnn_status xnnp_create_convolution2d_nhwc(
    uint32_t pad_top,
    uint32_t pad_right,
    uint32_t pad_bottom,
    uint32_t pad_left,
    uint32_t kernel_h,
    uint32_t kernel_w,
    uint32_t stride_h,
    uint32_t stride_w,
    uint32_t dilation_h,
    uint32_t dilation_w,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t ip_chan_stride,
    size_t op_chan_stride,
    int8_t izp,
    float ip_scale,
    int8_t kzp,
    const float* k_scales,
    const int8_t* kernel,
    const int32_t* bias,
    int8_t ozp,
    float op_scale,
    int8_t op_min,
    int8_t op_max,
    uint32_t flags,
    xnn_operator_t* op,
    bool per_channel,
    bool transpose) {
  /* Symmetric quantization forces kzp = 0 */
  TORCH_CHECK(!kzp, "XNNPACK Q[SC]8 conv kernels expect the kernel zero point to be zero. "
                    "But got: ", kzp);

  if (transpose) {
    TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
    return xnn_create_deconvolution2d_nhwc_qs8(
        pad_top,        /* uint32_t output_padding_top          */
        pad_right,      /* uint32_t output_padding_right        */
        pad_bottom,     /* uint32_t output_padding_bottom       */
        pad_left,       /* uint32_t output_padding_left         */
        kernel_h,       /* uint32_t kernel_height               */
        kernel_w,       /* uint32_t kernel_width                */
        stride_h,       /* uint32_t stride_height               */
        stride_w,       /* uint32_t stride_width                */
        dilation_h,     /* uint32_t dilation_height             */
        dilation_w,     /* uint32_t dilation_width              */
        groups,         /* uint32_t groups                      */
        group_input_channels,  /* size_t group_input_channels   */
        group_output_channels, /* size_t group_output_channels  */
        ip_chan_stride, /* size_t input_pixel_stride            */
        op_chan_stride, /* size_t output_pixel_stride           */
        izp,            /* int8_t input_zero_point              */
        ip_scale,       /* float input_scale                    */
        k_scales[0],    /* float kernel_scale                   */
        kernel,         /* const int8_t* kernel                 */
        bias,           /* const int32_t* bias                  */
        ozp,            /* int8_t output_zero_point             */
        op_scale,       /* float output_scale                   */
        op_min,         /* int8_t output_min                    */
        op_max,         /* int8_t output_max                    */
        flags,          /* uint32_t flags                       */
        nullptr,        /* xnn_caches_t caches                  */
        nullptr,        /* xnn_weights_cache_t weights_cache    */
        op);            /* xnn_operator_t* deconvolution_op_out */
  }

  if (!per_channel) {
    return xnn_create_convolution2d_nhwc_qs8(
        pad_top,        /* uint32_t input_padding_top         */
        pad_right,      /* uint32_t input_padding_right       */
        pad_bottom,     /* uint32_t input_padding_bottom      */
        pad_left,       /* uint32_t input_padding_left        */
        kernel_h,       /* uint32_t kernel_height             */
        kernel_w,       /* uint32_t kernel_width              */
        stride_h,       /* uint32_t subsampling_height        */
        stride_w,       /* uint32_t subsampling_width         */
        dilation_h,     /* uint32_t dilation_height           */
        dilation_w,     /* uint32_t dilation_width            */
        groups,         /* uint32_t groups                    */
        group_input_channels,  /* size_t group_input_channels */
        group_output_channels, /* size_t group_output_channels*/
        ip_chan_stride, /* size_t input_channel_stride        */
        op_chan_stride, /* size_t output_channel_stride       */
        izp,            /* int8_t input_zero_point            */
        ip_scale,       /* float input_scale                  */
        k_scales[0],    /* float kernel_scale                 */
        kernel,         /* const int8_t* kernel               */
        bias,           /* const int32_t* bias                */
        ozp,            /* int8_t output_zero_point           */
        op_scale,       /* float output_scale                 */
        op_min,         /* int8_t output_min                  */
        op_max,         /* int8_t output_max                  */
        flags,          /* uint32_t flags                     */
        nullptr,        /* xnn_caches_t caches                */
        nullptr,        /* xnn_weights_cache_t weights_cache  */
        op);            /* xnn_operator_t* convolution_op_out */
  } else { /* per_channel */
    return xnn_create_convolution2d_nhwc_qs8_qc8w(
        pad_top,        /* uint32_t input_padding_top         */
        pad_right,      /* uint32_t input_padding_right       */
        pad_bottom,     /* uint32_t input_padding_bottom      */
        pad_left,       /* uint32_t input_padding_left        */
        kernel_h,       /* uint32_t kernel_height             */
        kernel_w,       /* uint32_t kernel_width              */
        stride_h,       /* uint32_t subsampling_height        */
        stride_w,       /* uint32_t subsampling_width         */
        dilation_h,     /* uint32_t dilation_height           */
        dilation_w,     /* uint32_t dilation_width            */
        groups,         /* uint32_t groups                    */
        group_input_channels,  /* size_t group_input_channels */
        group_output_channels, /* size_t group_output_channels*/
        ip_chan_stride, /* size_t input_channel_stride        */
        op_chan_stride, /* size_t output_channel_stride       */
        izp,            /* int8_t input_zero_point            */
        ip_scale,       /* float input_scale                  */
        k_scales,       /* const float* kernel_scale          */
        kernel,         /* const int8_t* kernel               */
        bias,           /* const int32_t* bias                */
        ozp,            /* int8_t output_zero_point           */
        op_scale,       /* float output_scale                 */
        op_min,         /* int8_t output_min                  */
        op_max,         /* int8_t output_max                  */
        flags,          /* uint32_t flags                     */
        nullptr,        /* xnn_caches_t caches                */
        nullptr,        /* xnn_weights_cache_t weights_cache  */
        op);            /* xnn_operator_t* convolution_op_out */
  }
}

/*
 * Series of reshape wrapper functions to call xnn_reshape_[de]conv* functions.
 */
C10_ALWAYS_INLINE
enum xnn_status xnnp_reshape_convolution2d_nhwc(
    xnn_operator_t op,
    size_t batch,
    size_t in_h,
    size_t in_w,
    pthreadpool_t pt_pool,
    bool per_channel = false,
    bool transpose = false,
    uint32_t adj_h = 0,
    uint32_t adj_w = 0) {
  if (transpose) {
    TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
    return xnn_reshape_deconvolution2d_nhwc_qs8(
        op,       /* xnn_operator_t deconvolution_op */
        batch,    /* size_t batch_size               */
        in_h,     /* size_t input_height             */
        in_w,     /* size_t input_width              */
        adj_h,    /* uint32_t adjustment_height      */
        adj_w,    /* uint32_t adjustment_width       */
        nullptr,  /* size_t* output_height_out       */
        nullptr,  /* size_t* output_width_out        */
        pt_pool); /* pthreadpool_t threadpool        */
  }

  size_t workspace_size = SIZE_MAX;
  size_t workspace_alignment = SIZE_MAX;

  if (!per_channel) {
    return xnn_reshape_convolution2d_nhwc_qs8(
        op,       /* xnn_operator_t convolution_op */
        batch,    /* size_t batch_size             */
        in_h,     /* size_t input_height           */
        in_w,     /* size_t input_width            */
        &workspace_size,      /* size_t* workspace_size      */
        &workspace_alignment, /* size_t* workspace_alignment */
        nullptr,  /* size_t* output_height_out     */
        nullptr,  /* size_t* output_width_out      */
        pt_pool); /* pthreadpool_t threadpool      */
  } else { /* per_channel */
    return xnn_reshape_convolution2d_nhwc_qs8_qc8w(
        op,       /* xnn_operator_t convolution_op */
        batch,    /* size_t batch_size             */
        in_h,     /* size_t input_height           */
        in_w,     /* size_t input_width            */
        &workspace_size,      /* size_t* workspace_size      */
        &workspace_alignment, /* size_t* workspace_alignment */
        nullptr,  /* size_t* output_height_out     */
        nullptr,  /* size_t* output_width_out      */
        pt_pool); /* pthreadpool_t threadpool      */
  }
}


/*
 * Series of setup wrapper functions to call xnn_setup_[de]conv* functions.
 */
C10_ALWAYS_INLINE
enum xnn_status xnnp_setup_convolution2d_nhwc(
    xnn_operator_t op,
    const int8_t* inp,
    int8_t* outp,
    bool per_channel = false,
    bool transpose = false) {
  if (transpose) {
    TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");

    return xnn_setup_deconvolution2d_nhwc_qs8(
        op,       /* xnn_operator_t deconvolution_op */
        inp,      /* const int8_t* input             */
        outp);    /* int8_t* output                  */
  }

  if (!per_channel) {
    return xnn_setup_convolution2d_nhwc_qs8(
        op,       /* xnn_operator_t convolution_op   */
        nullptr,  /* void* workspace                 */
        inp,      /* const int8_t* input             */
        outp);    /* int8_t* output                  */
  } else { /* per_channel */
    return xnn_setup_convolution2d_nhwc_qs8_qc8w(
        op,       /* xnn_operator_t convolution_op   */
        nullptr,  /* void* workspace                 */
        inp,      /* const int8_t* input             */
        outp);    /* int8_t* output                  */
  }
}
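
/*
 * A minimal end-to-end usage sketch of the conv wrappers above (an assumed
 * call sequence, not prescribed by this header; `conv_op`, `input_ptr`,
 * `output_ptr`, `kernel_ptr`, `bias_ptr` and the scalar parameters are
 * hypothetical placeholders, and caffe2::pthreadpool_() is the threadpool
 * accessor typically used elsewhere in ATen's XNNPACK code):
 *
 *   xnn_operator_t conv_op = nullptr;
 *   xnn_status status = xnnp_create_convolution2d_nhwc(
 *       pad_top, pad_right, pad_bottom, pad_left,
 *       kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w,
 *       groups, group_input_channels, group_output_channels,
 *       ip_chan_stride, op_chan_stride,
 *       izp, ip_scale, 0, k_scales, kernel_ptr, bias_ptr,
 *       ozp, op_scale, op_min, op_max,
 *       0, &conv_op, per_channel, transpose);
 *   TORCH_CHECK(status == xnn_status_success, "create failed");
 *
 *   status = xnnp_reshape_convolution2d_nhwc(
 *       conv_op, batch, in_h, in_w, caffe2::pthreadpool_(),
 *       per_channel, transpose);
 *   TORCH_CHECK(status == xnn_status_success, "reshape failed");
 *
 *   status = xnnp_setup_convolution2d_nhwc(
 *       conv_op, input_ptr, output_ptr, per_channel, transpose);
 *   TORCH_CHECK(status == xnn_status_success, "setup failed");
 *
 *   status = xnn_run_operator(conv_op, caffe2::pthreadpool_());
 *   TORCH_CHECK(status == xnn_status_success, "run failed");
 */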


/*
 * Series of wrapper functions to call the xnn_create*, xnn_reshape*, and
 * xnn_setup* functions for linear (fully connected) operators.
 */
C10_ALWAYS_INLINE
enum xnn_status xnnp_create_fully_connected_nc(
    size_t input_channels,
    size_t output_channels,
    size_t input_stride,
    size_t output_stride,
    int8_t input_zero_point,
    float input_scale,
    int8_t kernel_zero_point,
    float kernel_scale,
    const int8_t* kernel,
    const int32_t* bias,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* fully_connected_op_out) {
  /* Symmetric quantization forces kzp = 0 */
  TORCH_CHECK(!kernel_zero_point, "XNNPACK QS8 linear kernel expects the kernel zero point to be zero. "
                    "But got: ", kernel_zero_point);
  return xnn_create_fully_connected_nc_qs8(
      input_channels,          /* size_t input_channels                  */
      output_channels,         /* size_t output_channels                 */
      input_stride,            /* size_t input_stride                    */
      output_stride,           /* size_t output_stride                   */
      input_zero_point,        /* int8_t input_zero_point                */
      input_scale,             /* float input_scale                      */
      kernel_scale,            /* float kernel_scale                     */
      kernel,                  /* const int8_t* kernel                   */
      bias,                    /* const int32_t* bias                    */
      output_zero_point,       /* int8_t output_zero_point               */
      output_scale,            /* float output_scale                     */
      output_min,              /* int8_t output_min                      */
      output_max,              /* int8_t output_max                      */
      flags,                   /* uint32_t flags                         */
      nullptr,                 /* xnn_caches_t caches                    */
      nullptr,                 /* xnn_weights_cache_t weights_cache      */
      fully_connected_op_out); /* xnn_operator_t* fully_connected_op_out */
}

C10_ALWAYS_INLINE
enum xnn_status xnnp_reshape_fully_connected_nc(
    xnn_operator_t fully_connected_op,
    size_t batch_size,
    pthreadpool_t threadpool) {
  return xnn_reshape_fully_connected_nc_qs8(
      fully_connected_op, /* xnn_operator_t fully_connected_op */
      batch_size,         /* size_t batch_size                 */
      threadpool);        /* pthreadpool_t threadpool          */
}

C10_ALWAYS_INLINE
enum xnn_status xnnp_setup_fully_connected_nc(
    xnn_operator_t fully_connected_op,
    const int8_t* input,
    int8_t* output) {
  return xnn_setup_fully_connected_nc_qs8(
      fully_connected_op, /* xnn_operator_t fully_connected_op */
      input,              /* const int8_t* input               */
      output);            /* int8_t* output                    */
}
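
/*
 * A minimal usage sketch of the linear wrappers above (an assumed call
 * sequence, not prescribed by this header; `fc_op`, `input_ptr`, `output_ptr`,
 * `kernel_ptr`, `bias_ptr` and the scalar parameters are hypothetical
 * placeholders):
 *
 *   xnn_operator_t fc_op = nullptr;
 *   xnn_status status = xnnp_create_fully_connected_nc(
 *       input_channels, output_channels, input_stride, output_stride,
 *       input_zero_point, input_scale, 0, kernel_scale,
 *       kernel_ptr, bias_ptr,
 *       output_zero_point, output_scale, output_min, output_max,
 *       0, &fc_op);
 *   TORCH_CHECK(status == xnn_status_success, "create failed");
 *
 *   status = xnnp_reshape_fully_connected_nc(
 *       fc_op, batch_size, caffe2::pthreadpool_());
 *   TORCH_CHECK(status == xnn_status_success, "reshape failed");
 *
 *   status = xnnp_setup_fully_connected_nc(fc_op, input_ptr, output_ptr);
 *   TORCH_CHECK(status == xnn_status_success, "setup failed");
 *
 *   status = xnn_run_operator(fc_op, caffe2::pthreadpool_());
 *   TORCH_CHECK(status == xnn_status_success, "run failed");
 */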

} // namespace xnnp_utils
} // namespace native
} // namespace at

#endif // USE_XNNPACK