1 #pragma once
2
3 #ifdef USE_XNNPACK
4 #include <cstdint>
5
6 #include <ATen/core/Tensor.h>
7 #include <ATen/native/xnnpack/Common.h>
8
9 using xnnpack_operator = at::native::xnnpack::Operator;
10
11 namespace at {
12 namespace native {
13 namespace xnnp_utils {
14
/*
 * Return the tensor's sizes reordered to match its memory format,
 * e.g. a channels_last tensor yields NHWC order instead of NCHW.
 */
std::vector<size_t> get_mem_format_aware_shape(const at::Tensor& in);
20
/*
 * Copy int8_t quantized weights from `in` to `out`, adding a fixed offset
 * determined by the destination type PT.
 * Input is always int8_t, output can be [int8_t, uint8_t].
 *   input + offset = output
 *   int8_t + 128 = uint8_t
 *   int8_t + 0   = int8_t
 */
template <typename PT>
void q8_copy_int8_weight_and_add_offset(const at::Tensor& in, at::Tensor& out);
29
/*
 * Reorder conv weights for XNNPACK's channels-last layout.
 * kSpatialDim selects 1d/2d/3d convolution weight shapes; `transpose`
 * indicates deconvolution weights, whose group/channel axes differ.
 */
template <int kSpatialDim>
Tensor convert_conv_weights_to_channel_last_tensor(
    const at::Tensor& src,
    int groups,
    bool transpose);
35
36 /*
37 * Series of create wrapper functions to call xnn_create_[de]conv* functions.
38 */
39 C10_ALWAYS_INLINE
xnnp_create_convolution2d_nhwc(uint32_t pad_top,uint32_t pad_right,uint32_t pad_bottom,uint32_t pad_left,uint32_t kernel_h,uint32_t kernel_w,uint32_t stride_h,uint32_t stride_w,uint32_t dilation_h,uint32_t dilation_w,uint32_t groups,size_t group_input_channels,size_t group_output_channels,size_t ip_chan_stride,size_t op_chan_stride,int8_t izp,float ip_scale,int8_t kzp,const float * k_scales,const int8_t * kernel,const int32_t * bias,int8_t ozp,float op_scale,int8_t op_min,int8_t op_max,uint32_t flags,xnn_operator_t * op,bool per_channel,bool transpose)40 enum xnn_status xnnp_create_convolution2d_nhwc(
41 uint32_t pad_top,
42 uint32_t pad_right,
43 uint32_t pad_bottom,
44 uint32_t pad_left,
45 uint32_t kernel_h,
46 uint32_t kernel_w,
47 uint32_t stride_h,
48 uint32_t stride_w,
49 uint32_t dilation_h,
50 uint32_t dilation_w,
51 uint32_t groups,
52 size_t group_input_channels,
53 size_t group_output_channels,
54 size_t ip_chan_stride,
55 size_t op_chan_stride,
56 int8_t izp,
57 float ip_scale,
58 int8_t kzp,
59 const float* k_scales,
60 const int8_t* kernel,
61 const int32_t* bias,
62 int8_t ozp,
63 float op_scale,
64 int8_t op_min,
65 int8_t op_max,
66 uint32_t flags,
67 xnn_operator_t* op,
68 bool per_channel,
69 bool transpose) {
70 /* Symmetric quantization forces kzp = 0 */
71 TORCH_CHECK(!kzp, "XNNPACK Q[SC]8 conv kernels expects kernel zero point to be zero."
72 "But got: ", kzp);
73
74 if (transpose) {
75 TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
76 return xnn_create_deconvolution2d_nhwc_qs8(
77 pad_top, /* uint32_t output_padding_top */
78 pad_right, /* uint32_t output_padding_right */
79 pad_bottom, /* uint32_t output_padding_bottom */
80 pad_left, /* uint32_t output_padding_left */
81 kernel_h, /* uint32_t kernel_height */
82 kernel_w, /* uint32_t kernel_width */
83 stride_h, /* uint32_t stride_height */
84 stride_w, /* uint32_t stride_width */
85 dilation_h, /* uint32_t dilation_height */
86 dilation_w, /* uint32_t dilation_width */
87 groups, /* uint32_t groups */
88 group_input_channels, /* size_t group_input_channels */
89 group_output_channels, /* size_t group_output_channels */
90 ip_chan_stride, /* size_t input_pixel_stride */
91 op_chan_stride, /* size_t output_pixel_stride */
92 izp, /* int8_t input_zero_point */
93 ip_scale, /* float input_scale */
94 k_scales[0], /* float kernel_scale */
95 kernel, /* const int8_t* kernel */
96 bias, /* const int32_t* bias */
97 ozp, /* int8_t output_zero_point */
98 op_scale, /* float output_scale */
99 op_min, /* int8_t output_min */
100 op_max, /* int8_t output_max */
101 flags, /* uint32_t flags */
102 nullptr, /* xnn_caches_t caches */
103 nullptr, /* xnn_weights_cache_t weights_cache */
104 op); /* xnn_operator_t* deconvolution_op_out */
105
106 }
107
108 if (!per_channel) {
109 return xnn_create_convolution2d_nhwc_qs8(
110 pad_top, /* uint32_t input_padding_top */
111 pad_right, /* uint32_t input_padding_right */
112 pad_bottom, /* uint32_t input_padding_bottom */
113 pad_left, /* uint32_t input_padding_left */
114 kernel_h, /* uint32_t kernel_height */
115 kernel_w, /* uint32_t kernel_width */
116 stride_h, /* uint32_t subsampling_height */
117 stride_w, /* uint32_t subsampling_width */
118 dilation_h, /* uint32_t dilation_height */
119 dilation_w, /* uint32_t dilation_width */
120 groups, /* uint32_t groups */
121 group_input_channels, /* size_t group_input_channels */
122 group_output_channels, /* size_t group_output_channels*/
123 ip_chan_stride, /* size_t input_channel_stride */
124 op_chan_stride, /* size_t output_channel_stride */
125 izp, /* int8_t input_zero_point */
126 ip_scale, /* float input_scale */
127 k_scales[0], /* float kernel_scale */
128 kernel, /* const int8_t* kernel */
129 bias, /* const int32_t* bias */
130 ozp, /* int8_t output_zero_point */
131 op_scale, /* float output_scale */
132 op_min, /* int8_t output_min */
133 op_max, /* int8_t output_max */
134 flags, /* uint32_t flags */
135 nullptr, /* xnn_caches_t caches */
136 nullptr, /* xnn_weights_cache_t weights_cache */
137 op); /* xnn_operator_t* convolution_op_out */
138 } else { /* per_channel */
139 return xnn_create_convolution2d_nhwc_qs8_qc8w(
140 pad_top, /* uint32_t input_padding_top */
141 pad_right, /* uint32_t input_padding_right */
142 pad_bottom, /* uint32_t input_padding_bottom */
143 pad_left, /* uint32_t input_padding_left */
144 kernel_h, /* uint32_t kernel_height */
145 kernel_w, /* uint32_t kernel_width */
146 stride_h, /* uint32_t subsampling_height */
147 stride_w, /* uint32_t subsampling_width */
148 dilation_h, /* uint32_t dilation_height */
149 dilation_w, /* uint32_t dilation_width */
150 groups, /* uint32_t groups */
151 group_input_channels, /* size_t group_input_channels */
152 group_output_channels, /* size_t group_output_channels*/
153 ip_chan_stride, /* size_t input_channel_stride */
154 op_chan_stride, /* size_t output_channel_stride */
155 izp, /* int8_t input_zero_point */
156 ip_scale, /* float input_scale */
157 k_scales, /* const float* kernel_scale */
158 kernel, /* const int8_t* kernel */
159 bias, /* const int32_t* bias */
160 ozp, /* int8_t output_zero_point */
161 op_scale, /* float output_scale */
162 op_min, /* int8_t output_min */
163 op_max, /* int8_t output_max */
164 flags, /* uint32_t flags */
165 nullptr, /* xnn_caches_t caches */
166 nullptr, /* xnn_weights_cache_t weights_cache */
167 op); /* xnn_operator_t* convolution_op_out */
168 }
169 }
170
171 /*
172 * Series of reshape wrapper functions to call xnn_reshape_[de]conv* functions.
173 */
174 C10_ALWAYS_INLINE
175 enum xnn_status xnnp_reshape_convolution2d_nhwc(
176 xnn_operator_t op,
177 size_t batch,
178 size_t in_h,
179 size_t in_w,
180 pthreadpool_t pt_pool,
181 bool per_channel = false,
182 bool transpose = false,
183 uint32_t adj_h = 0,
184 uint32_t adj_w = 0) {
185 if(transpose) {
186 TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
187 return xnn_reshape_deconvolution2d_nhwc_qs8(
188 op, /* xnn_operator_t deconvolution_op */
189 batch, /* size_t batch_size */
190 in_h, /* size_t input_height */
191 in_w, /* size_t input_width */
192 adj_h, /* uint32_t adjustment_height */
193 adj_w, /* uint32_t adjustment_width */
194 nullptr, /* size_t* output_height_out */
195 nullptr, /* size_t* output_width_out */
196 pt_pool); /* pthreadpool_t threadpool */
197 }
198
199 size_t workspace_size = SIZE_MAX;
200 size_t workspace_alignment = SIZE_MAX;
201
202 if (!per_channel) {
203 return xnn_reshape_convolution2d_nhwc_qs8(
204 op, /* xnn_operator_t convolution_op */
205 batch, /* size_t batch_size */
206 in_h, /* size_t input_height */
207 in_w, /* size_t input_width */
208 &workspace_size, /* size_t* workspace_size */
209 &workspace_alignment, /* size_t* workspace_alignment */
210 nullptr, /* size_t* output_height_out */
211 nullptr, /* size_t* output_width_out */
212 pt_pool); /* pthreadpool_t threadpool */
213 } else { /* per_channel */
214 return xnn_reshape_convolution2d_nhwc_qs8_qc8w(
215 op, /* xnn_operator_t convolution_op */
216 batch, /* size_t batch_size */
217 in_h, /* size_t input_height */
218 in_w, /* size_t input_width */
219 &workspace_size, /* size_t* workspace_size */
220 &workspace_alignment, /* size_t* workspace_alignment */
221 nullptr, /* size_t* output_height_out */
222 nullptr, /* size_t* output_width_out */
223 pt_pool); /* pthreadpool_t threadpool */
224 }
225 }
226
227
228 /*
229 * Series of setup wrapper functions to call xnn_setup_[de]conv* functions.
230 */
231 C10_ALWAYS_INLINE
232 enum xnn_status xnnp_setup_convolution2d_nhwc(
233 xnn_operator_t op,
234 const int8_t* inp,
235 int8_t* outp,
236 bool per_channel = false,
237 bool transpose = false) {
238 if(transpose) {
239 TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
240
241 return xnn_setup_deconvolution2d_nhwc_qs8(
242 op, /* xnn_operator_t deconvolution_op */
243 inp, /* const int8_t* input */
244 outp); /* int8_t* output */
245 }
246
247 if (!per_channel) {
248 return xnn_setup_convolution2d_nhwc_qs8(
249 op, /* xnn_operator_t deconvolution_op */
250 nullptr, /* void workspace */
251 inp, /* const int8_t* input */
252 outp); /* int8_t* output */
253 } else { /* per_channel */
254 return xnn_setup_convolution2d_nhwc_qs8_qc8w(
255 op, /* xnn_operator_t deconvolution_op */
256 nullptr, /* void workspace */
257 inp, /* const int8_t* input */
258 outp); /* int8_t* output */
259 }
260 }
261
262
263 /*
264 * Series of wrapper functions to call xnn_create* and xnn_setup*
265 * functions for linear
266 */
267 C10_ALWAYS_INLINE
xnnp_create_fully_connected_nc(size_t input_channels,size_t output_channels,size_t input_stride,size_t output_stride,int8_t input_zero_point,float input_scale,int8_t kernel_zero_point,float kernel_scale,const int8_t * kernel,const int32_t * bias,int8_t output_zero_point,float output_scale,int8_t output_min,int8_t output_max,uint32_t flags,xnn_operator_t * fully_connected_op_out)268 enum xnn_status xnnp_create_fully_connected_nc(
269 size_t input_channels,
270 size_t output_channels,
271 size_t input_stride,
272 size_t output_stride,
273 int8_t input_zero_point,
274 float input_scale,
275 int8_t kernel_zero_point,
276 float kernel_scale,
277 const int8_t* kernel,
278 const int32_t* bias,
279 int8_t output_zero_point,
280 float output_scale,
281 int8_t output_min,
282 int8_t output_max,
283 uint32_t flags,
284 xnn_operator_t* fully_connected_op_out) {
285 /* Symmetric quantization forces kzp = 0 */
286 TORCH_CHECK(!kernel_zero_point, "XNNPACK QS8 linear kernel expects kernel zero point to be zero."
287 "But got: ", kernel_zero_point);
288 return xnn_create_fully_connected_nc_qs8(
289 input_channels, /* size_t input_channels */
290 output_channels, /* size_t output_channels */
291 input_stride, /* size_t input_stride */
292 output_stride, /* size_t output_stride */
293 input_zero_point, /* int8_t input_zero_point */
294 input_scale, /* float input_scale */
295 kernel_scale, /* float kernel_scale */
296 kernel, /* const int8_t* kernel */
297 bias, /* const int32_t* bias */
298 output_zero_point, /* int8_t output_zero_point */
299 output_scale, /* float output_scale */
300 output_min, /* int8_t output_min */
301 output_max, /* int8_t output_max */
302 flags, /* uint32_t flags */
303 nullptr, /* xnn_caches_t caches */
304 nullptr, /* xnn_weights_cache_t */
305 fully_connected_op_out); /* xnn_operator_t* fully_connected_op_out */
306 }
307
308 C10_ALWAYS_INLINE
xnnp_reshape_fully_connected_nc(xnn_operator_t fully_connected_op,size_t batch_size,pthreadpool_t threadpool)309 enum xnn_status xnnp_reshape_fully_connected_nc(
310 xnn_operator_t fully_connected_op,
311 size_t batch_size,
312 pthreadpool_t threadpool) {
313 return xnn_reshape_fully_connected_nc_qs8(
314 fully_connected_op, /* xnn_operator_t fully_connected_op */
315 batch_size, /* size_t batch_size */
316 threadpool); /* pthreadpool_t threadpool */
317 }
318
319 C10_ALWAYS_INLINE
xnnp_setup_fully_connected_nc(xnn_operator_t fully_connected_op,const int8_t * input,int8_t * output)320 enum xnn_status xnnp_setup_fully_connected_nc(
321 xnn_operator_t fully_connected_op,
322 const int8_t* input,
323 int8_t* output) {
324 return xnn_setup_fully_connected_nc_qs8(
325 fully_connected_op, /* xnn_operator_t fully_connected_op */
326 input, /* const int8_t* input */
327 output /* int8_t* output */
328 );
329 }
330
331 } // namespace xnnp_utils
332 } // namespace native
333 } // namespace at
334
335 #endif // USE_XNNPACK
336