#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <utility>
#include <vector>

#include <ATen/core/Tensor.h>
#include <ATen/core/List.h>
#include <ATen/Context.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <ATen/native/quantized/cpu/OnednnUtils.h>
#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <torch/library.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/zeros.h>
#endif

#include <c10/util/irange.h>

#ifdef USE_FBGEMM
template <int kSpatialDim>
c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeight<
    kSpatialDim>::
    prepack(
        at::Tensor weight,
        std::optional<at::Tensor> bias,
        torch::List<int64_t> stride,
        torch::List<int64_t> padding,
        torch::List<int64_t> output_padding,
        torch::List<int64_t> dilation,
        int64_t groups,
        bool transpose) {
  TORCH_CHECK(
      weight.ndimension() == kSpatialDim + 2,
      "Weights are expected to have ",
      kSpatialDim + 2,
      " dimensions");
  TORCH_CHECK(
      stride.size() == kSpatialDim,
      "stride should contain ",
      kSpatialDim,
      " elements for ",
      kSpatialDim,
      "D convolution.");
  TORCH_CHECK(
      padding.size() == kSpatialDim,
      "Specify front/top/left padding only. "
      "end/bottom/right padding assumed to be equal to front/top/left");
  TORCH_CHECK(
      !transpose || output_padding.size() == kSpatialDim,
      "quantized::conv_prepack: Specify top/left output padding "
      "only. bottom/right padding assumed to be equal to top/left");
  TORCH_CHECK(
      dilation.size() == kSpatialDim,
      "dilation should contain ",
      kSpatialDim,
      " elements for ",
      kSpatialDim,
      "D convolution.");
  const int input_channels = transpose ? weight.size(0)
                                       : weight.size(1) * groups;
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  const int output_channels = transpose ? weight.size(1) * groups
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
                                        : weight.size(0);
  const int kernel_d = kSpatialDim == 2 ? 1 : weight.size(2);
  const int kernel_h = weight.size(kSpatialDim);
  const int kernel_w = weight.size(kSpatialDim + 1);

  // The mini-batch size and the input image height/width have no effect on
  // how the weights are packed, so dummy values (batch size 1, 28x28 image)
  // are passed below.
  const fbgemm::conv_param_t<kSpatialDim> conv_p =
      at::native::fbgemm_utils::MakeFbgemmConvParam<kSpatialDim>(
          1, // dummy batch size
          input_channels,
          output_channels,
          kSpatialDim == 2 ? std::vector<int>{28, 28} // dummy image size
                           : std::vector<int>{28, 28, 28},
          groups,
          kSpatialDim == 2 ? std::vector<int>{kernel_h, kernel_w}
                           : std::vector<int>{kernel_d, kernel_h, kernel_w},
          std::vector<int>(stride.begin(), stride.end()),
          std::vector<int>(padding.begin(), padding.end()),
          std::vector<int>(dilation.begin(), dilation.end()),
          std::vector<int>(output_padding.begin(), output_padding.end()),
          transpose);

  const auto qtype = weight.qscheme();
  std::vector<int32_t> zero_points;
  if (qtype == c10::kPerTensorAffine) {
    zero_points = {static_cast<int32_t>(weight.q_zero_point())};
  } else if (qtype == c10::kPerChannelAffine) {
    TORCH_CHECK(
        !transpose,
        "Per Channel Quantization is currently disabled for transposed conv");
    zero_points.resize(output_channels);
    for (const auto i : c10::irange(output_channels)) {
      zero_points[i] = weight.q_per_channel_zero_points()[i].item<int32_t>();
    }
  } else {
    TORCH_CHECK(false, "Unsupported qscheme: ", toString(qtype));
  }

  // FBGEMM expects weights to be in channels last
  // TODO: Change this when ChannelsLast3d is ready.
  // FBGEMM needs G OC/G kDim0 ... kDimN IC/G
  // for both conv and conv transpose
  // but PyTorch lays them out as {out_c, in_c/groups, kH, kW}
  // (or for ConvTranspose {in_c, out_c/groups, kH, kW})
  const at::Tensor weight_nhwc =
      at::native::fbgemm_utils::ConvertConvWeightsToChannelLastTensor<kSpatialDim>(weight, groups, transpose);
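  // Illustrative example of the layout change (shapes only, assumed values):
  // with groups = 2, out_c = 8, in_c = 4 and a 3x3 kernel, the PyTorch conv
  // weight has shape {8, 2, 3, 3}, while FBGEMM consumes it as
  // G x OC/G x kH x kW x IC/G, i.e. logically {2, 4, 3, 3, 2}.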
  const int8_t* weight_data_int8 =
      reinterpret_cast<int8_t*>(weight_nhwc.data_ptr<c10::qint8>());
  std::vector<int32_t> col_offsets(output_channels);
  // Compute column offsets (similar to
  // fbgemm::col_offsets_with_zero_pt_s8acc32_ref). Note that each offset
  // includes the sum of the column as well as the scalar term
  // weight_zero_point * KDim.
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  const int input_channels_per_group = input_channels / groups;
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  const int output_channels_per_group = output_channels / groups;
  const int inner_size =
      kernel_d * kernel_h * kernel_w * input_channels_per_group;
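  // Worked example of the offset formula (assumed values, not executed): with
  // inner_size = 4, per-tensor zero_point = 2 and int8 weights {3, -1, 0, 5}
  // for one output channel, the loop below yields
  //   col_offset = (3 - 1 + 0 + 5) - 2 * 4 = 7 - 8 = -1.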
  for (const auto g : c10::irange(groups)) {
    for (const auto i : c10::irange(output_channels_per_group)) {
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      const int c = g * output_channels_per_group + i;
      int32_t sum = 0;
      for (const auto j : c10::irange(inner_size)) {
        sum += static_cast<int32_t>(weight_data_int8[c * inner_size + j]);
      }
      if (qtype == c10::kPerTensorAffine) {
        col_offsets[c] = sum - zero_points[0] * inner_size;
      } else {
        col_offsets[c] = sum - zero_points[c] * inner_size;
      }
    }
  }

  std::vector<float> scales;
  if (qtype == c10::kPerTensorAffine) {
    scales = {static_cast<float>(weight.q_scale())};
  } else if (qtype == c10::kPerChannelAffine) {
    scales.resize(output_channels);
    for (const auto i : c10::irange(output_channels)) {
      scales[i] = weight.q_per_channel_scales()[i].item<float>();
    }
  }

  std::optional<at::Tensor> bias_contig;
  if (bias.has_value()) {
    at::Tensor bias_vec = bias.value();
    TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)");
    TORCH_CHECK(
        bias_vec.size(0) == output_channels,
        "bias should have K elements: " + std::to_string(output_channels));
    bias_contig = bias->contiguous();
  }

  auto ret_ptr = c10::make_intrusive<PackedConvWeight<kSpatialDim>>(
      PackedConvWeight<kSpatialDim>{
          std::make_unique<fbgemm::PackWeightsForConv<kSpatialDim>>(
              conv_p, weight_data_int8),
          bias_contig,
          stride,
          padding,
          output_padding,
          dilation,
          groups,
          transpose,
          col_offsets,
          kSpatialDim == 2 ? std::vector<int64_t>{kernel_h, kernel_w}
                           : std::vector<int64_t>{kernel_d, kernel_h, kernel_w},
          scales,
          zero_points,
          qtype});

  return ret_ptr;
}

template struct PackedConvWeight<2>;
template struct PackedConvWeight<3>;
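// Illustrative usage sketch (comment only, not compiled as part of this file);
// the tensor shape, scale and zero point below are assumptions made for the
// sake of the example:
//
//   at::Tensor wq = at::quantize_per_tensor(
//       at::randn({16, 3, 3, 3}), /*scale=*/0.05, /*zero_point=*/0, at::kQInt8);
//   auto packed = PackedConvWeight<2>::prepack(
//       wq, /*bias=*/std::nullopt,
//       /*stride=*/{1, 1}, /*padding=*/{1, 1}, /*output_padding=*/{0, 0},
//       /*dilation=*/{1, 1}, /*groups=*/1, /*transpose=*/false);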
#endif // USE_FBGEMM

#ifdef USE_PYTORCH_QNNPACK
template <int kSpatialDim>
c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightsQnnp<
    kSpatialDim>::
    prepack(
        at::Tensor weight,
        std::optional<at::Tensor> bias_in,
        torch::List<int64_t> stride,
        torch::List<int64_t> padding,
        torch::List<int64_t> output_padding,
        torch::List<int64_t> dilation,
        int64_t groups,
        bool transpose) {
  TORCH_CHECK(
      kSpatialDim == 2 || kSpatialDim == 3, // 1D is packed as 2D, hence we don't need other checks
      "QNNPACK packing only supports 2D / 3D convolution.");
  TORCH_CHECK(
      weight.ndimension() == kSpatialDim + 2,
      "quantized::conv_prepack (qnnpack): Weights are expected to have ",
      kSpatialDim + 2, " dimensions, found shape ", weight.sizes());
  TORCH_CHECK(
      stride.size() == kSpatialDim,
      "quantized::conv_prepack (qnnpack): ",
      kSpatialDim, "D convolution expects stride to have ",
      kSpatialDim, " elements.");
  TORCH_CHECK(
      padding.size() == kSpatialDim,
      "quantized::conv_prepack (qnnpack): Specify top/left input padding "
      "only. bottom/right padding assumed to be equal to top/left");
  TORCH_CHECK(
      !transpose || output_padding.size() == kSpatialDim,
      "quantized::conv_prepack (qnnpack): Specify top/left output padding "
      "only. bottom/right padding assumed to be equal to top/left");
  TORCH_CHECK(
      dilation.size() == kSpatialDim,
      "quantized::conv_prepack (qnnpack): ",
      kSpatialDim, "D convolution expects dilation to have ",
      kSpatialDim, " elements.");

  at::native::initQNNPACK();

  // QNNPACK expects weights to be of the format {out_c, kH, kW, in_c/groups},
  // but PyTorch lays them out as {out_c, in_c/groups, kH, kW}
  // (or for ConvTranspose {in_c, out_c/groups, kH, kW})
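  // For example (assumed shapes): a PyTorch conv weight of logical shape
  // {out_c = 16, in_c/groups = 4, kH = 3, kW = 5} is consumed by QNNPACK with
  // the channel dimension last, i.e. as {16, 3, 5, 4}; the ChannelsLast
  // contiguous() call below produces exactly that physical layout without
  // changing the logical sizes.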
  const auto out_ch = transpose ? weight.size(1) * groups : weight.size(0);
  const uint32_t kernel_d = kSpatialDim == 3 ? weight.size(2) : 1;
  const uint32_t kernel_h = weight.size(kSpatialDim);
  const uint32_t kernel_w = weight.size(kSpatialDim + 1);

  at::Tensor bias_fp32;
  if (bias_in.has_value()) {
    bias_fp32 = bias_in.value();
  } else {
    bias_fp32 = at::zeros(out_ch, weight.options().dtype(at::kFloat));
  }

  TORCH_CHECK(
      !bias_fp32.defined() ||
          (bias_fp32.ndimension() == 1 && bias_fp32.size(0) == out_ch),
      "quantized::conv_prepack (qnnpack): expected bias to be 1-dimensional "
      "with ",
      out_ch,
      " elements",
      ", but got bias of size ",
      bias_fp32.sizes(),
      " instead. "
      "(weight dimensions: ",
      weight.sizes(), " , transpose: ",
      (transpose ? "True)." : "False).")
  );

  auto weight_contig = weight.contiguous(
      kSpatialDim == 2 ? c10::MemoryFormat::ChannelsLast
                       : c10::MemoryFormat::ChannelsLast3d);
  const bool is_per_channel = weight_contig.qscheme() == at::kPerChannelAffine;
  auto kernel_dim = kSpatialDim == 2
      ? std::vector<int64_t>{kernel_h, kernel_w}
      : std::vector<int64_t>{kernel_d, kernel_h, kernel_w};
  auto [w_zero_points, w_scales] =
      make_zero_points_and_scales_tensor(weight_contig, transpose, groups);
  // We set the pre-packed conv weights to nullptr below as we call pre-pack
  // during the first invocation of operator run. Refer to qconv.cpp for more
  // details. TODO Update to actually call pre-pack here once bias is removed
  // from pre-packing step.
  auto ret_ptr = c10::intrusive_ptr<PackedConvWeightsQnnp<kSpatialDim>>::make(
      nullptr, /* PrePackConvWeights */
      weight_contig, /* int8_t weight */
      bias_fp32.contiguous(), /* fp32 bias */
      stride,
      padding,
      output_padding,
      dilation,
      groups,
      transpose,
      std::nullopt, /* input_scale */
      kernel_dim,
      w_scales,
      std::move(w_zero_points),
      is_per_channel);

  return ret_ptr;
}

template
c10::intrusive_ptr<ConvPackedParamsBase<2>> PackedConvWeightsQnnp<
    2>::
    prepack(
        at::Tensor weight,
        std::optional<at::Tensor> bias_in,
        torch::List<int64_t> stride,
        torch::List<int64_t> padding,
        torch::List<int64_t> output_padding,
        torch::List<int64_t> dilation,
        int64_t groups,
        bool transpose);
#endif // USE_PYTORCH_QNNPACK

#if AT_MKLDNN_ENABLED()
template <int kSpatialDim>
c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightsOnednn<
    kSpatialDim>::
    prepack(
        at::Tensor weight,
        std::optional<at::Tensor> bias,
        torch::List<int64_t> stride,
        torch::List<int64_t> padding,
        torch::List<int64_t> output_padding,
        torch::List<int64_t> dilation,
        int64_t groups,
        bool transpose) {
  TORCH_CHECK(
      weight.ndimension() == kSpatialDim + 2,
      "Weights are expected to have ", kSpatialDim + 2, " dimensions");
  TORCH_CHECK(
      stride.size() == kSpatialDim,
      "stride should contain ", kSpatialDim, " elements for ",
      kSpatialDim, "D convolution.");
  TORCH_CHECK(
      padding.size() == kSpatialDim,
      "Specify front/top/left padding only. "
      "end/bottom/right padding assumed to be equal to front/top/left");
  TORCH_CHECK(
      !transpose || output_padding.size() == kSpatialDim,
      "quantized::conv_prepack: Specify top/left output padding "
      "only. bottom/right padding assumed to be equal to top/left");
  TORCH_CHECK(
      dilation.size() == kSpatialDim,
      "dilation should contain ", kSpatialDim, " elements for ",
      kSpatialDim, "D convolution.");
  TORCH_CHECK(
      !transpose || std::all_of(output_padding.begin(), output_padding.end(), [](int i) { return i==0; }),
      "quantized::conv_prepack: ONEDNN only supports zero output_padding.");

  // Weight
  // Format: [OC IC//group KH KW] for conv; [IC OC//group KH KW] for deconv
  auto dims = weight.sizes().vec();
  auto strides = stride.vec();
  auto padding_l = padding.vec();
  auto padding_r = padding.vec();
  auto dilates = dilation.vec();
  auto op_attr = ideep::attr_t();
  std::vector<int32_t> wgt_zero_points;
  ideep::scale_t wgt_scales;
  const int output_channels = transpose ? weight.size(1) * groups
                                        : weight.size(0);
  const auto qtype = weight.qscheme();
  if (qtype == c10::kPerTensorAffine) {
    TORCH_CHECK(
        weight.q_zero_point()==0,
        "quantized::qconv_prepack: ONEDNN only supports symmetric quantization of weight,"
        " whose zero point must be 0.");
    wgt_zero_points = std::vector<int32_t>(1, weight.q_zero_point());
#if IDEEP_PREREQ(3, 1, 0, 1)
    wgt_scales = ideep::scale_t(1, weight.q_scale());
#elif IDEEP_PREREQ(3, 1, 0, 0)
    wgt_scales = ideep::scale_t(1, 1.0/weight.q_scale()); // Scales of ONEDNN and PyTorch are reciprocal
#else
    TORCH_CHECK(false, "Unexpected IDeep version to do qconv weight prepack.");
#endif
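    // Worked example of the reciprocal convention (assumed values): a PyTorch
    // per-tensor weight scale of 0.1 is passed to ideep unchanged as 0.1 for
    // IDEEP >= 3.1.0.1, but as 1.0 / 0.1 = 10.0 on the older 3.1.0.0 branch.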
  } else if (qtype == c10::kPerChannelAffine) {
    TORCH_CHECK(
        !transpose,
        "Per Channel Quantization is currently disabled for transposed conv");
    wgt_zero_points.resize(output_channels);
    wgt_scales.resize(output_channels);
    for (int i = 0; i < output_channels; ++i) {
      wgt_zero_points[i] = weight.q_per_channel_zero_points()[i].item<int32_t>();
      TORCH_CHECK(
          wgt_zero_points[i]==0,
          "quantized::qconv_prepack: ONEDNN only supports symmetric quantization of weight,"
          " whose zero point must be 0.");
#if IDEEP_PREREQ(3, 1, 0, 1)
      wgt_scales[i] = weight.q_per_channel_scales()[i].item<float>();
#elif IDEEP_PREREQ(3, 1, 0, 0)
      wgt_scales[i] = 1.0f / weight.q_per_channel_scales()[i].item<float>(); // Scales of ONEDNN and PyTorch are reciprocal
#else
      TORCH_CHECK(false, "Unexpected IDeep version to do qconv weight prepack.");
#endif
    }
  } else {
    TORCH_CHECK(false, "Unsupported qscheme: ", toString(qtype));
  }

  // Set runtime src zero point
  op_attr.set_zero_points_mask(DNNL_ARG_SRC, /* zero_points_mask= */0);
  at::Tensor weight_copy;
  ideep::tensor::desc w_desc;
  ideep::dims dims_iohw, dims_giohw;
  ideep::tag w_tag = ideep::tag::any;
  const bool with_groups = groups > 1;
  if (transpose) {
    // template args: <(src/dst) is_channels_last, transposed>
    w_desc = ideep::convolution_transpose_forward::expected_weights_desc<true, false>(
        dims, dnnl::memory::data_type::s8,
        strides, padding_l, padding_r, dilates, groups,
        dnnl::algorithm::deconvolution_direct, dnnl::prop_kind::forward_inference,
        ideep::dims(), op_attr);
    // convolution_transpose_forward::expected_weights_desc() gives format [i, o, ...],
    // but ONEDNN requires [o, i, ...] for computation
    dims_iohw = w_desc.get_dims();
    dims_giohw = with_groups ? ideep::utils::group_dims(dims_iohw, groups) : dims_iohw;
    std::vector<int64_t> perms(dims_giohw.size(), 0); // for permutation of weight
    std::iota(perms.begin(), perms.end(), 0);
    std::swap(perms[with_groups], perms[with_groups + 1]);
    weight_copy = weight.reshape(dims_giohw).permute(c10::IntArrayRef(perms)).clone();
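    // Illustrative example (assumed shapes, ungrouped case): a PyTorch deconv
    // weight of shape {IC = 4, OC = 8, 3, 3} yields perms = {1, 0, 2, 3}, so
    // weight_copy comes out with shape {8, 4, 3, 3}, i.e. the [o, i, ...]
    // order ONEDNN expects.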
  } else {
    w_desc = ideep::convolution_forward::expected_weights_desc(
        dims, dnnl::memory::data_type::s8,
        strides, padding_l, padding_r, dilates, groups,
        dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference,
        dnnl::memory::data_type::u8, ideep::dims(), op_attr, /*is_channels_last=*/true);
    weight_copy = weight.clone();
  }
  if (with_groups) {
    w_tag = kSpatialDim == 2 ? ideep::tag::goihw : ideep::tag::goidhw;
  } else {
    w_tag = kSpatialDim == 2 ? ideep::tag::oihw : ideep::tag::oidhw;
  }
  ideep::dims w_dims = with_groups ? ideep::utils::group_dims(w_desc.get_dims(), groups)
                                   : w_desc.get_dims();
  ideep::tensor wgt = ideep::tensor(
      ideep::tensor::desc({w_dims, dnnl::memory::data_type::s8, w_tag}, groups),
      weight_copy.data_ptr());
  wgt.set_scale(wgt_scales); // Scales are needed for feed_from().
  ideep::tensor exp_wgt;
  exp_wgt.init(w_desc);
  exp_wgt.set_scale(wgt_scales); // Also for feed_from()
  exp_wgt.feed_from(wgt, transpose); // expect wgt to be in [OC IC KH KW] format
  ideep::tensor * packed_weight_p = new ideep::tensor(std::move(exp_wgt));
  packed_weight_p->set_scale(wgt_scales);
  packed_weight_p->set_zero_point(wgt_zero_points);
  std::unique_ptr<ideep::tensor> weight_ptr(packed_weight_p);
  // Bias
  std::optional<ideep::tensor> onednn_bias{std::nullopt};
  if (bias.has_value()) {
    at::Tensor bias_vec = bias.value();
    TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)");
    TORCH_CHECK(
        bias_vec.size(0) == output_channels,
        "bias should have K elements: " + std::to_string(output_channels));
    auto bias_desc = ideep::tensor::desc(bias.value().sizes().vec(), dnnl::memory::data_type::f32);
    ideep::tensor packed_bias;
    packed_bias.init(bias_desc, bias.value().data_ptr());
    onednn_bias = std::optional<ideep::tensor>(packed_bias);
  }
  auto ret_ptr = c10::make_intrusive<PackedConvWeightsOnednn<kSpatialDim>>(
      PackedConvWeightsOnednn<kSpatialDim>{
        std::move(weight_ptr),
        onednn_bias,
        weight,
        bias,
        stride,
        padding,
        output_padding,
        dilation,
        groups,
        transpose
      });
  return ret_ptr;
}

template struct PackedConvWeightsOnednn<2>;
template struct PackedConvWeightsOnednn<3>;

// Return the packed weight as Mkldnn Tensor
at::Tensor _qconv_prepack_onednn(
    at::Tensor weight, // from CPU backend instead of QuantizedCPU
    at::Tensor weight_scales, // Weight zero points must be 0 for onednn
    double input_scale,
    int64_t input_zero_point,
    torch::List<int64_t> stride,
    torch::List<int64_t> padding,
    torch::List<int64_t> dilation,
    int64_t groups,
    std::optional<torch::List<int64_t>> input_shape) {
  int kSpatialDim = weight.ndimension() - 2;
  TORCH_CHECK(
      weight.ndimension() == kSpatialDim + 2,
      "Weights are expected to have ", kSpatialDim + 2, " dimensions");
  TORCH_CHECK(
      stride.size() == (decltype(stride.size()))kSpatialDim,
      "stride should contain ", kSpatialDim, " elements for ",
      kSpatialDim, "D convolution.");
  TORCH_CHECK(
      padding.size() == (decltype(padding.size()))kSpatialDim,
      "Specify front/top/left padding only. "
      "end/bottom/right padding assumed to be equal to front/top/left");
  TORCH_CHECK(
      dilation.size() == (decltype(dilation.size()))kSpatialDim,
      "dilation should contain ", kSpatialDim, " elements for ",
      kSpatialDim, "D convolution.");

  bool is_1d = (1 == kSpatialDim);
  auto x_dims = input_shape.has_value()?input_shape.value().vec():ideep::dims();
  if (is_1d) {
    if (input_shape.has_value()) {
      // N, C, L -> N, C, 1, L
      x_dims.insert(x_dims.begin() + 2, 1);
    }
    if (weight.dim() == 3) {
      weight = weight.unsqueeze(quant_utils::kConv1dSqueezeDim + 2);
    }
    stride = quant_utils::MakeArgForConv1d(stride, 1);
    padding = quant_utils::MakeArgForConv1d(padding, 0);
    dilation = quant_utils::MakeArgForConv1d(dilation, 1);
    kSpatialDim += 1;
  }
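  // Example of the 1D-to-2D lifting above (assumed values): a Conv1d weight of
  // shape {OC, IC, L} becomes {OC, IC, 1, L}, an input shape {N, C, L} becomes
  // {N, C, 1, L}, and the single-element stride/padding/dilation lists are
  // padded to two elements with the given defaults (1, 0 and 1) for the dummy
  // spatial dimension.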
  auto w_dims = weight.sizes().vec();
  auto strides = stride.vec();
  auto padding_l = padding.vec();
  auto padding_r = padding.vec();
  auto dilates = dilation.vec();
  auto op_attr = ideep::attr_t();

  ideep::scale_t weights_scales(weight_scales.numel());

  if (weight_scales.ndimension() == 0) {
    // Weight is quant per tensor, then weight_scales will be a scalar Tensor
    TORCH_CHECK(
        weight_scales.numel() == 1,
        "Weight is quant per tensor, weight scale expects 1 element but got ", weight_scales.numel(), " elements.");
#if IDEEP_PREREQ(3, 1, 0, 1)
    weights_scales[0] = weight_scales.item().toDouble();
#elif IDEEP_PREREQ(3, 1, 0, 0)
    weights_scales[0] = 1.0 / weight_scales.item().toDouble(); // Scales of ONEDNN and PyTorch are reciprocal
#else
    TORCH_CHECK(false, "Unexpected IDeep version to do qconv weight prepack.");
#endif
  } else {
    // Weight is quant per channel
    for (int i = 0; i < weight_scales.numel(); ++i) {
#if IDEEP_PREREQ(3, 1, 0, 1)
      weights_scales[i] = weight_scales[i].item().toDouble();
#elif IDEEP_PREREQ(3, 1, 0, 0)
      weights_scales[i] = 1.0 / weight_scales[i].item().toDouble();
#else
      TORCH_CHECK(false, "Unexpected IDeep version to do qconv weight prepack.");
#endif
    }
  }

  if (input_scale != 1.0f) {
    op_attr.set_scales_mask(DNNL_ARG_SRC, /* src_scales_mask= */0);
  }
  if (input_zero_point != 0) {
    op_attr.set_zero_points_mask(DNNL_ARG_SRC, /* src_zero_points_mask= */0);
  }

  at::Tensor weight_copy;
  ideep::tensor::desc w_desc;
  ideep::dims dims_iohw, dims_giohw;
  ideep::tag w_tag = ideep::tag::any;
  const bool with_groups = groups > 1;
  w_desc = ideep::convolution_forward::expected_weights_desc(
      w_dims, dnnl::memory::data_type::s8,
      strides, padding_l, padding_r, dilates, groups,
      dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference,
      dnnl::memory::data_type::u8, x_dims, op_attr, /*is_channels_last=*/true);

  // Note: a Conv1d weight has already been unsqueezed to Conv2d shape above.
  weight_copy = weight.clone(c10::MemoryFormat::Contiguous);

  if (with_groups) {
    w_tag = kSpatialDim == 2 ? ideep::tag::goihw : ideep::tag::goidhw;
  } else {
    w_tag = kSpatialDim == 2 ? ideep::tag::oihw : ideep::tag::oidhw;
  }
  ideep::dims wei_dims = with_groups ? ideep::utils::group_dims(w_desc.get_dims(), groups)
                                     : w_desc.get_dims();
  ideep::tensor wgt = ideep::tensor(
      ideep::tensor::desc({wei_dims, dnnl::memory::data_type::s8, w_tag}, groups),
      weight_copy.data_ptr());

  wgt.set_scale(weights_scales); // Scales are needed for feed_from().

  ideep::tensor exp_wgt;
  exp_wgt.init(w_desc);
  exp_wgt.set_scale(weights_scales); // Also for feed_from()
  exp_wgt.feed_from(wgt, /*transposed*/false); // expect wgt to be in [OC IC KH KW] format

  auto packed_weight = at::native::new_with_itensor_mkldnn(
      std::move(exp_wgt),
      c10::optTypeMetaToScalarType(weight_copy.options().dtype_opt()),
      weight_copy.options().device_opt());

  return packed_weight;
}

#endif // #if AT_MKLDNN_ENABLED()

namespace at {
namespace native {
namespace {

template <int kSpatialDim = 2>
class QConvPackWeightInt8 final {
 public:
  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> run_conv(
      Tensor weight,
      std::optional<Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> dilation,
      int64_t groups) {
    torch::List<int64_t> output_padding;
    output_padding.reserve(kSpatialDim);
    for (C10_UNUSED const auto idx : c10::irange(kSpatialDim)) {
      output_padding.push_back((int64_t)0);
    }
    return _run(weight, bias, stride, padding, output_padding, dilation, groups,
                /*transpose=*/false);
  }

  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> run_deconv(
      Tensor weight,
      std::optional<Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups) {
    return _run(weight, bias, stride, padding, output_padding, dilation, groups,
                /*transpose=*/true);
  }

 private:
  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> _run(
      Tensor weight,
      std::optional<Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose) {
    auto& ctx = at::globalContext();
#ifdef USE_FBGEMM
    if (ctx.qEngine() == at::QEngine::X86) {
#if AT_MKLDNN_ENABLED()
      bool use_onednn = onednn_utils::should_use_onednn_quant(
          weight, transpose, groups, output_padding);
      if (use_onednn) {
        return PackedConvWeightsOnednn<kSpatialDim>::prepack(
            weight, bias, stride, padding, output_padding, dilation, groups, transpose);
      }
#endif
      return PackedConvWeight<kSpatialDim>::prepack(
          weight, bias, stride, padding, output_padding, dilation, groups, transpose);
    } // x86
#endif // USE_FBGEMM

#ifdef USE_FBGEMM
    if (ctx.qEngine() == at::QEngine::FBGEMM) {
      return PackedConvWeight<kSpatialDim>::prepack(
          weight, bias, stride, padding, output_padding, dilation, groups,
          transpose);
    }
#endif

#ifdef USE_PYTORCH_QNNPACK
    if (ctx.qEngine() == at::QEngine::QNNPACK) {
      return PackedConvWeightsQnnp<kSpatialDim>::prepack(
          weight, bias, stride, padding, output_padding, dilation, groups,
          transpose);
    }
#endif

#if AT_MKLDNN_ENABLED()
    if (ctx.qEngine() == at::QEngine::ONEDNN) {
      return PackedConvWeightsOnednn<kSpatialDim>::prepack(
          weight, bias, stride, padding, output_padding, dilation, groups,
          transpose);
    }
#endif

    TORCH_CHECK(
        false,
        "Didn't find engine for operation quantized::conv2d_prepack ",
        toString(ctx.qEngine()));
  }
};


class QConv1dPackWeightInt8 final {
 public:
  static c10::intrusive_ptr<ConvPackedParamsBase<2>> run_conv(
      Tensor weight,
      std::optional<Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> dilation,
      int64_t groups) {
    const torch::List<int64_t> output_padding({0});
    return _run(std::move(weight), std::move(bias), stride, padding, output_padding, dilation, groups,
                /*transpose=*/false);
  }

  static c10::intrusive_ptr<ConvPackedParamsBase<2>> run_deconv(
      Tensor weight,
      std::optional<Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups) {
    return _run(std::move(weight), std::move(bias), stride, padding, output_padding, dilation, groups,
                /*transpose=*/true);
  }

 private:
  static c10::intrusive_ptr<ConvPackedParamsBase<2>> _run(
      Tensor weight,
      std::optional<Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose) {
    auto& ctx = at::globalContext();
    if (weight.dim() == 3) {
      weight = weight.unsqueeze(quant_utils::kConv1dSqueezeDim + 2);
    }
    stride = quant_utils::MakeArgForConv1d(stride, 1);
    padding = quant_utils::MakeArgForConv1d(padding, 0);
    output_padding = quant_utils::MakeArgForConv1d(output_padding, 0);
    dilation = quant_utils::MakeArgForConv1d(dilation, 1);

#ifdef USE_FBGEMM
    if (ctx.qEngine() == at::QEngine::X86) {
#if AT_MKLDNN_ENABLED()
      bool use_onednn = onednn_utils::should_use_onednn_quant(
          weight, transpose, groups, output_padding);
      if (use_onednn) {
        return PackedConvWeightsOnednn<2>::prepack(
            weight, bias, stride, padding, output_padding, dilation, groups,
            transpose);
      }
#endif
      return PackedConvWeight<2>::prepack(
          std::move(weight), std::move(bias), stride, padding, output_padding, dilation, groups,
          transpose);

    } // x86
#endif

#ifdef USE_FBGEMM
    if (ctx.qEngine() == at::QEngine::FBGEMM) {
      return PackedConvWeight<2>::prepack(
          std::move(weight), std::move(bias), stride, padding, output_padding, dilation, groups,
          transpose);
    }
#endif

#ifdef USE_PYTORCH_QNNPACK
    if (ctx.qEngine() == at::QEngine::QNNPACK) {
      return PackedConvWeightsQnnp<2>::prepack(
          std::move(weight), std::move(bias), stride, padding, output_padding, dilation, groups,
          transpose);
    }
#endif

#if AT_MKLDNN_ENABLED()
    if (ctx.qEngine() == at::QEngine::ONEDNN) {
      return PackedConvWeightsOnednn<2>::prepack(
          weight, bias, stride, padding, output_padding, dilation, groups,
          transpose);
    }
#endif

    TORCH_CHECK(
        false,
        "Didn't find engine for operation quantized::conv1d_prepack ",
        toString(ctx.qEngine()));
  }
};

class QConvPrepackOneDNN final {
 public:
  static at::Tensor run_conv(
      at::Tensor weight, // from CPU backend instead of QuantizedCPU
      at::Tensor weight_scales, // Weight zero points must be 0s for onednn
      double input_scale,
      int64_t input_zero_point,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      std::optional<torch::List<int64_t>> input_shape) {
#if AT_MKLDNN_ENABLED()
    return _qconv_prepack_onednn(
        weight, weight_scales, input_scale, input_zero_point,
        stride, padding, dilation, groups, input_shape);
#else
    TORCH_CHECK(false, "Unimplemented as onednn is not available.");
#endif
  }
};

TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  // Conv
  // conv_prepack is deprecated, please use conv2d_prepack for 2D conv.
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv_prepack"), TORCH_FN(QConvPackWeightInt8<2>::run_conv));
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d_prepack"), TORCH_FN(QConv1dPackWeightInt8::run_conv));
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_prepack"), TORCH_FN(QConvPackWeightInt8<2>::run_conv));
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv3d_prepack"), TORCH_FN(QConvPackWeightInt8<3>::run_conv));
  // ConvTranspose
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose1d_prepack"), TORCH_FN(QConv1dPackWeightInt8::run_deconv));
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_prepack"), TORCH_FN(QConvPackWeightInt8<2>::run_deconv));
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_prepack"), TORCH_FN(QConvPackWeightInt8<3>::run_deconv));
}
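// Illustrative Python usage of the registrations above (a sketch, assuming a
// per-tensor quantized qint8 weight; the argument values are assumptions made
// for the example):
//
//   w_q = torch.quantize_per_tensor(torch.randn(16, 3, 3, 3), 0.05, 0, torch.qint8)
//   packed = torch.ops.quantized.conv2d_prepack(
//       w_q, None, [1, 1], [1, 1], [1, 1], 1)  # stride, padding, dilation, groups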

TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) {
  // Conv
  m.impl(TORCH_SELECTIVE_NAME("_quantized::conv2d_prepack"), TORCH_FN(QConvPackWeightInt8<2>::run_conv));
  m.impl(TORCH_SELECTIVE_NAME("_quantized::conv3d_prepack"), TORCH_FN(QConvPackWeightInt8<3>::run_conv));
  // ConvTranspose
  m.impl(TORCH_SELECTIVE_NAME("_quantized::conv_transpose1d_prepack"), TORCH_FN(QConv1dPackWeightInt8::run_deconv));
  m.impl(TORCH_SELECTIVE_NAME("_quantized::conv_transpose2d_prepack"), TORCH_FN(QConvPackWeightInt8<2>::run_deconv));
  m.impl(TORCH_SELECTIVE_NAME("_quantized::conv_transpose3d_prepack"), TORCH_FN(QConvPackWeightInt8<3>::run_deconv));
}

TORCH_LIBRARY_IMPL(onednn, CPU, m) {
  // New OP definition for Quantization in PyTorch 2.0 Export
  // Conv Prepack
  m.impl(TORCH_SELECTIVE_NAME("onednn::qconv_prepack"), TORCH_FN(QConvPrepackOneDNN::run_conv));
}
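// A sketch of calling onednn::qconv_prepack from Python (illustrative only;
// the positional arguments mirror QConvPrepackOneDNN::run_conv above, and the
// example weight/scale values are assumptions):
//
//   w_int8 = torch.randint(-128, 127, (16, 3, 3, 3), dtype=torch.int8)
//   w_scales = torch.full((16,), 0.05)
//   packed = torch.ops.onednn.qconv_prepack(
//       w_int8, w_scales, 1.0, 0, [1, 1], [1, 1], [1, 1], 1, None)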

} // namespace
} // namespace native
} // namespace at