#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <utility>
#include <vector>

#include <ATen/core/Tensor.h>
#include <ATen/core/List.h>
#include <ATen/Context.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <ATen/native/quantized/cpu/OnednnUtils.h>
#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <torch/library.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/zeros.h>
#endif

#include <c10/util/irange.h>

#ifdef USE_FBGEMM
template <int kSpatialDim>
c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeight<
    kSpatialDim>::
    prepack(
        at::Tensor weight,
        std::optional<at::Tensor> bias,
        torch::List<int64_t> stride,
        torch::List<int64_t> padding,
        torch::List<int64_t> output_padding,
        torch::List<int64_t> dilation,
        int64_t groups,
        bool transpose) {
  TORCH_CHECK(
      weight.ndimension() == kSpatialDim + 2,
      "Weights are expected to have ",
      kSpatialDim + 2,
      " dimensions");
  TORCH_CHECK(
      stride.size() == kSpatialDim,
      "stride should contain ",
      kSpatialDim,
      " elements for ",
      kSpatialDim,
      "D convolution.");
  TORCH_CHECK(
      padding.size() == kSpatialDim,
      "Specify front/top/left padding only. "
      "end/bottom/right padding assumed to be equal to front/top/left");
  TORCH_CHECK(
      !transpose || output_padding.size() == kSpatialDim,
      "quantized::conv_prepack: Specify top/left output padding "
      "only. bottom/right padding assumed to be equal to top/left");
  TORCH_CHECK(
      dilation.size() == kSpatialDim,
      "dilation should contain ",
      kSpatialDim,
      " elements for ",
      kSpatialDim,
      "D convolution.");
  const int input_channels = transpose ? weight.size(0)
                                       : weight.size(1) * groups;
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  const int output_channels = transpose ? weight.size(1) * groups
                                        // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
                                        : weight.size(0);
  const int kernel_d = kSpatialDim == 2 ? 1 : weight.size(2);
  const int kernel_h = weight.size(kSpatialDim);
  const int kernel_w = weight.size(kSpatialDim + 1);

  // mini-batch doesn't have any impact on how we pack weights
  // so we pass it as 1
  // Input image height/width also don't have any impact on how we pack
  // weights so we can pass any values
  const fbgemm::conv_param_t<kSpatialDim> conv_p =
      at::native::fbgemm_utils::MakeFbgemmConvParam<kSpatialDim>(
          1, // dummy batch size
          input_channels,
          output_channels,
          kSpatialDim == 2 ? std::vector<int>{28, 28} // dummy image size
                           : std::vector<int>{28, 28, 28},
          groups,
          kSpatialDim == 2 ? std::vector<int>{kernel_h, kernel_w}
                           : std::vector<int>{kernel_d, kernel_h, kernel_w},
          std::vector<int>(stride.begin(), stride.end()),
          std::vector<int>(padding.begin(), padding.end()),
          std::vector<int>(dilation.begin(), dilation.end()),
          std::vector<int>(output_padding.begin(), output_padding.end()),
          transpose);

  const auto qtype = weight.qscheme();
  std::vector<int32_t> zero_points;
  if (qtype == c10::kPerTensorAffine) {
    zero_points = {static_cast<int32_t>(weight.q_zero_point())};
  } else if (qtype == c10::kPerChannelAffine) {
    TORCH_CHECK(
        !transpose,
        "Per Channel Quantization is currently disabled for transposed conv");
    zero_points.resize(output_channels);
    for (const auto i : c10::irange(output_channels)) {
      zero_points[i] = weight.q_per_channel_zero_points()[i].item<int32_t>();
    }
  } else {
    TORCH_CHECK(false, "Unsupported qscheme: ", toString(qtype));
  }

  // FBGEMM expects weights to be in channels last
  // TODO: Change this when ChannelsLast3d is ready.
  // FBGEMM needs G OC/G kDim0 ... kDimN IC/G
  // for both conv and conv transpose
  // but PyTorch lays them out as {out_c, in_c/groups, kH, kW}
  // (or for ConvTranspose {in_c, out_c/groups, kH, kW})
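  // e.g., instantiating the generic layout above for a grouped 2D conv, the
  // FBGEMM ordering corresponds to {G, OC/G, kH, kW, IC/G}, while the incoming
  // PyTorch weight is {OC, IC/G, kH, kW}.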
  const at::Tensor weight_nhwc =
      at::native::fbgemm_utils::ConvertConvWeightsToChannelLastTensor<kSpatialDim>(weight, groups, transpose);
  const int8_t* weight_data_int8 =
          reinterpret_cast<int8_t*>(weight_nhwc.data_ptr<c10::qint8>());
  std::vector<int32_t> col_offsets(output_channels);
  // Compute column offsets (similar to
  // fbgemm::col_offsets_with_zero_pt_s8acc32_ref). Note that the offsets
  // combine the per-column sum with the scalar term weight_zero_point * KDim.
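  // Concretely, the loop below computes, for each output channel c,
  //   col_offsets[c] = sum_j W_int8[c][j] - zero_point(c) * inner_size
  // where inner_size = kernel_d * kernel_h * kernel_w * input_channels_per_group
  // and zero_point(c) is the per-tensor or per-channel weight zero point.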
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  const int input_channels_per_group = input_channels / groups;
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  const int output_channels_per_group = output_channels / groups;
  const int inner_size =
      kernel_d * kernel_h * kernel_w * input_channels_per_group;
  for (const auto g : c10::irange(groups)) {
    for (const auto i : c10::irange(output_channels_per_group)) {
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      const int c = g * output_channels_per_group + i;
      int32_t sum = 0;
      for (const auto j : c10::irange(inner_size)) {
        sum += static_cast<int32_t>(weight_data_int8[c * inner_size + j]);
      }
      if (qtype == c10::kPerTensorAffine) {
        col_offsets[c] = sum - zero_points[0] * inner_size;
      } else {
        col_offsets[c] = sum - zero_points[c] * inner_size;
      }
    }
  }

  std::vector<float> scales;
  if (qtype == c10::kPerTensorAffine) {
    scales = {static_cast<float>(weight.q_scale())};
  } else if (qtype == c10::kPerChannelAffine) {
    scales.resize(output_channels);
    for (const auto i : c10::irange(output_channels)) {
      scales[i] = weight.q_per_channel_scales()[i].item<float>();
    }
  }

  std::optional<at::Tensor> bias_contig;
  if (bias.has_value()) {
    at::Tensor bias_vec = bias.value();
    TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)");
    TORCH_CHECK(
        bias_vec.size(0) == output_channels,
        "bias should have K elements: " + std::to_string(output_channels));
    bias_contig = bias->contiguous();
  }

  auto ret_ptr = c10::make_intrusive<PackedConvWeight<kSpatialDim>>(
      PackedConvWeight<kSpatialDim>{
          std::make_unique<fbgemm::PackWeightsForConv<kSpatialDim>>(
              conv_p, weight_data_int8),
          bias_contig,
          stride,
          padding,
          output_padding,
          dilation,
          groups,
          transpose,
          col_offsets,
          kSpatialDim == 2 ? std::vector<int64_t>{kernel_h, kernel_w}
                           : std::vector<int64_t>{kernel_d, kernel_h, kernel_w},
          scales,
          zero_points,
          qtype});

  return ret_ptr;
}

template struct PackedConvWeight<2>;
template struct PackedConvWeight<3>;
#endif // USE_FBGEMM

#ifdef USE_PYTORCH_QNNPACK
template <int kSpatialDim>
c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightsQnnp<
    kSpatialDim>::
    prepack(
        at::Tensor weight,
        std::optional<at::Tensor> bias_in,
        torch::List<int64_t> stride,
        torch::List<int64_t> padding,
        torch::List<int64_t> output_padding,
        torch::List<int64_t> dilation,
        int64_t groups,
        bool transpose) {
  TORCH_CHECK(
      kSpatialDim == 2 || kSpatialDim == 3,  // 1D is packed as 2d, hence we don't need other checks
      "QNNPACK packing only supports 2D / 3D convolution.");
  TORCH_CHECK(
      weight.ndimension() == kSpatialDim + 2,
      "quantized::conv_prepack (qnnpack): Weights are expected to have ",
      kSpatialDim + 2, " dimensions, found shape ", weight.sizes());
  TORCH_CHECK(
      stride.size() == kSpatialDim,
      "quantized::conv_prepack (qnnpack): ",
      kSpatialDim, "D convolution expects stride to have ",
      kSpatialDim, " elements.");
  TORCH_CHECK(
      padding.size() == kSpatialDim,
      "quantized::conv_prepack (qnnpack): Specify top/left input padding "
      "only. bottom/right padding assumed to be equal to top/left");
  TORCH_CHECK(
      !transpose || output_padding.size() == kSpatialDim,
      "quantized::conv_prepack (qnnpack): Specify top/left output padding "
      "only. bottom/right padding assumed to be equal to top/left");
  TORCH_CHECK(
      dilation.size() == kSpatialDim,
      "quantized::conv_prepack (qnnpack): ",
      kSpatialDim, "D convolution expects dilation to have ",
      kSpatialDim, " elements.");

  at::native::initQNNPACK();

  // QNNPACK expects weights to be of the format {out_c, kH, kW, in_c/groups},
  // but PyTorch lays them out as {out_c, in_c/groups, kH, kW}
  // (or for ConvTranspose {in_c, out_c/groups, kH, kW})
  const auto out_ch = transpose ? weight.size(1) * groups : weight.size(0);
  const uint32_t kernel_d = kSpatialDim == 3 ? weight.size(2) : 1;
  const uint32_t kernel_h = weight.size(kSpatialDim);
  const uint32_t kernel_w = weight.size(kSpatialDim + 1);

  at::Tensor bias_fp32;
  if (bias_in.has_value()) {
    bias_fp32 = bias_in.value();
  } else {
    bias_fp32 = at::zeros(out_ch, weight.options().dtype(at::kFloat));
  }

  TORCH_CHECK(
      !bias_fp32.defined() ||
          (bias_fp32.ndimension() == 1 && bias_fp32.size(0) == out_ch),
      "quantized::conv2d_prepack (qnnpack): expected bias to be 1-dimensional "
      "with ",
      out_ch,
      " elements",
      ", but got bias of size ",
      bias_fp32.sizes(),
      " instead. "
      "(weight dimensions: ",
      weight.sizes(), " , transpose: ",
      (transpose ? "True)." : "False).")
  );

  TORCH_CHECK(
      !bias_fp32.defined() ||
          (bias_fp32.ndimension() == 1 && bias_fp32.size(0) == out_ch),
      "quantized::conv3d_prepack (qnnpack): expected bias to be 1-dimensional "
      "with ",
      out_ch,
      " elements",
      ", but got bias of size ",
      bias_fp32.sizes(),
      " instead. "
      "(weight dimensions: ",
      weight.sizes(), " , transpose: ",
      (transpose ? "True)." : "False).")
  );

  auto weight_contig = weight.contiguous(
      kSpatialDim == 2 ? c10::MemoryFormat::ChannelsLast
                       : c10::MemoryFormat::ChannelsLast3d);
  const bool is_per_channel = weight_contig.qscheme() == at::kPerChannelAffine;
  auto kernel_dim = kSpatialDim == 2
      ? std::vector<int64_t>{kernel_h, kernel_w}
      : std::vector<int64_t>{kernel_d, kernel_h, kernel_w};
  auto [w_zero_points, w_scales] =
      make_zero_points_and_scales_tensor(weight_contig, transpose, groups);
  // We set the pre-packed conv weights to nullptr below as we call pre-pack
  // during the first invocation of operator run. Refer to qconv.cpp for more
  // details. TODO Update to actually call pre-pack here once bias is removed
  // from pre-packing step.
  auto ret_ptr = c10::intrusive_ptr<PackedConvWeightsQnnp<kSpatialDim>>::make(
      nullptr, /* PrePackConvWeights */
      weight_contig, /* int8_t weight */
      bias_fp32.contiguous(), /* fp32 bias */
      stride,
      padding,
      output_padding,
      dilation,
      groups,
      transpose,
      std::nullopt, /* input_scale */
      kernel_dim,
      w_scales,
      std::move(w_zero_points),
      is_per_channel);

  return ret_ptr;
}

template
c10::intrusive_ptr<ConvPackedParamsBase<2>> PackedConvWeightsQnnp<
    2>::
    prepack(
        at::Tensor weight,
        std::optional<at::Tensor> bias_in,
        torch::List<int64_t> stride,
        torch::List<int64_t> padding,
        torch::List<int64_t> output_padding,
        torch::List<int64_t> dilation,
        int64_t groups,
        bool transpose);
#endif // USE_PYTORCH_QNNPACK

#if AT_MKLDNN_ENABLED()
template <int kSpatialDim>
c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightsOnednn<
    kSpatialDim>::
    prepack(
        at::Tensor weight,
        std::optional<at::Tensor> bias,
        torch::List<int64_t> stride,
        torch::List<int64_t> padding,
        torch::List<int64_t> output_padding,
        torch::List<int64_t> dilation,
        int64_t groups,
        bool transpose) {
  TORCH_CHECK(
      weight.ndimension() == kSpatialDim + 2,
      "Weights are expected to have ", kSpatialDim + 2, " dimensions");
  TORCH_CHECK(
      stride.size() == kSpatialDim,
      "stride should contain ", kSpatialDim, " elements for ",
      kSpatialDim, "D convolution.");
  TORCH_CHECK(
      padding.size() == kSpatialDim,
      "Specify front/top/left padding only. "
      "end/bottom/right padding assumed to be equal to front/top/left");
  TORCH_CHECK(
      !transpose || output_padding.size() == kSpatialDim,
      "quantized::conv_prepack: Specify top/left output padding "
      "only. bottom/right padding assumed to be equal to top/left");
  TORCH_CHECK(
      dilation.size() == kSpatialDim,
      "dilation should contain ", kSpatialDim, " elements for ",
      kSpatialDim, "D convolution.");
  TORCH_CHECK(
      !transpose || std::all_of(output_padding.begin(), output_padding.end(), [](int i) { return i==0; }),
      "quantized::conv_prepack: ONEDNN only supports zero output_padding.");

  // Weight
  // Format: [OC IC//group KH KW] for conv; [IC OC//group KH KW] for deconv
  auto dims = weight.sizes().vec();
  auto strides = stride.vec();
  auto padding_l = padding.vec();
  auto padding_r = padding.vec();
  auto dilates = dilation.vec();
  auto op_attr = ideep::attr_t();
  std::vector<int32_t> wgt_zero_points;
  ideep::scale_t wgt_scales;
  const int output_channels = transpose ? weight.size(1) * groups
                                        : weight.size(0);
  const auto qtype = weight.qscheme();
  if (qtype == c10::kPerTensorAffine) {
    TORCH_CHECK(
        weight.q_zero_point()==0,
        "quantized::qconv_prepack: ONEDNN only supports symmetric quantization of weight,"
        " whose zero point must be 0.");
    wgt_zero_points = std::vector<int32_t>(1, weight.q_zero_point());
#if IDEEP_PREREQ(3, 1, 0, 1)
    wgt_scales = ideep::scale_t(1, weight.q_scale());
#elif IDEEP_PREREQ(3, 1, 0, 0)
    wgt_scales = ideep::scale_t(1, 1.0/weight.q_scale()); // Scales of ONEDNN and PyTorch are reciprocal
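    // e.g., with this older ideep interface a PyTorch weight scale of 0.02
    // would be stored as 1.0 / 0.02 = 50.0.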
#else
    TORCH_CHECK(false, "Unexpected IDeep version to do qconv weight prepack.");
#endif
  } else if (qtype == c10::kPerChannelAffine) {
    TORCH_CHECK(
        !transpose,
        "Per Channel Quantization is currently disabled for transposed conv");
    wgt_zero_points.resize(output_channels);
    wgt_scales.resize(output_channels);
    for (int i = 0; i < output_channels; ++i) {
      wgt_zero_points[i] = weight.q_per_channel_zero_points()[i].item<int32_t>();
      TORCH_CHECK(
          wgt_zero_points[i]==0,
          "quantized::qconv_prepack: ONEDNN only supports symmetric quantization of weight,"
          " whose zero point must be 0.");
#if IDEEP_PREREQ(3, 1, 0, 1)
      wgt_scales[i] = weight.q_per_channel_scales()[i].item<float>();
#elif IDEEP_PREREQ(3, 1, 0, 0)
      wgt_scales[i] = 1.0f / weight.q_per_channel_scales()[i].item<float>(); // Scales of ONEDNN and PyTorch are reciprocal
#else
      TORCH_CHECK(false, "Unexpected IDeep version to do qconv weight prepack.");
#endif
    }
  } else {
    TORCH_CHECK(false, "Unsupported qscheme: ", toString(qtype));
  }

  // Set runtime src zero point
  op_attr.set_zero_points_mask(DNNL_ARG_SRC, /* zero_points_mask= */0);
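  // (mask 0 tells oneDNN to expect a single, per-tensor src zero point that is
  // supplied at execution time)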
  at::Tensor weight_copy;
  ideep::tensor::desc w_desc;
  ideep::dims dims_iohw, dims_giohw;
  ideep::tag w_tag = ideep::tag::any;
  const bool with_groups = groups > 1;
  if (transpose) {
    // template args: <(src/dst) is_channels_last, transposed>
    w_desc = ideep::convolution_transpose_forward::expected_weights_desc<true, false>(
        dims, dnnl::memory::data_type::s8,
        strides, padding_l, padding_r, dilates, groups,
        dnnl::algorithm::deconvolution_direct, dnnl::prop_kind::forward_inference,
        ideep::dims(), op_attr);
    // convolution_transpose_forward::expected_weights_desc() gives format [i, o, ...],
    // but ONEDNN requires [o, i, ...] for computation
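    // e.g., without groups the permutation built below turns the reshaped
    // [i, o, kH, kW] weight into [o, i, kH, kW]; with groups, the two dims
    // following the leading group dim are swapped instead.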
    dims_iohw = w_desc.get_dims();
    dims_giohw = with_groups ? ideep::utils::group_dims(dims_iohw, groups) : dims_iohw;
    std::vector<int64_t> perms(dims_giohw.size(), 0); // for permutation of weight
    std::iota(perms.begin(), perms.end(), 0);
    std::swap(perms[with_groups], perms[with_groups + 1]);
    weight_copy = weight.reshape(dims_giohw).permute(c10::IntArrayRef(perms)).clone();
  } else {
    w_desc = ideep::convolution_forward::expected_weights_desc(
        dims, dnnl::memory::data_type::s8,
        strides, padding_l, padding_r, dilates, groups,
        dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference,
        dnnl::memory::data_type::u8, ideep::dims(), op_attr, /*is_channels_last=*/true);
    weight_copy = weight.clone();
  }
  if (with_groups) {
    w_tag = kSpatialDim == 2 ? ideep::tag::goihw : ideep::tag::goidhw;
  } else {
    w_tag = kSpatialDim == 2 ? ideep::tag::oihw : ideep::tag::oidhw;
  }
  ideep::dims w_dims = with_groups ? ideep::utils::group_dims(w_desc.get_dims(), groups)
                                   : w_desc.get_dims();
  ideep::tensor wgt = ideep::tensor(
      ideep::tensor::desc({w_dims, dnnl::memory::data_type::s8, w_tag}, groups),
      weight_copy.data_ptr());
  wgt.set_scale(wgt_scales); // Scales are needed for feed_from().
  ideep::tensor exp_wgt;
  exp_wgt.init(w_desc);
  exp_wgt.set_scale(wgt_scales); // Also for feed_from()
  exp_wgt.feed_from(wgt, transpose); // expect wgt to be in [OC IC KH KW] format
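  // feed_from() reorders the plain-layout weight above into the (possibly
  // blocked) layout described by w_desc, which is what the convolution
  // primitive will consume.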
  ideep::tensor * packed_weight_p = new ideep::tensor(std::move(exp_wgt));
  packed_weight_p->set_scale(wgt_scales);
  packed_weight_p->set_zero_point(wgt_zero_points);
  std::unique_ptr<ideep::tensor> weight_ptr(packed_weight_p);
  // Bias
  std::optional<ideep::tensor> onednn_bias{std::nullopt};
  if (bias.has_value()) {
    at::Tensor bias_vec = bias.value();
    TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)");
    TORCH_CHECK(
        bias_vec.size(0) == output_channels,
        "bias should have K elements: " + std::to_string(output_channels));
    auto bias_desc = ideep::tensor::desc(bias.value().sizes().vec(), dnnl::memory::data_type::f32);
    ideep::tensor packed_bias;
    packed_bias.init(bias_desc, bias.value().data_ptr());
    onednn_bias = std::optional<ideep::tensor>(packed_bias);
  }
  auto ret_ptr = c10::make_intrusive<PackedConvWeightsOnednn<kSpatialDim>>(
      PackedConvWeightsOnednn<kSpatialDim>{
        std::move(weight_ptr),
        onednn_bias,
        weight,
        bias,
        stride,
        padding,
        output_padding,
        dilation,
        groups,
        transpose
      });
  return ret_ptr;
}

template struct PackedConvWeightsOnednn<2>;
template struct PackedConvWeightsOnednn<3>;

// Return the packed weight as Mkldnn Tensor
at::Tensor _qconv_prepack_onednn(
    at::Tensor weight, // from CPU backend instead of QuantizedCPU
    at::Tensor weight_scales, // Weight zero points must be 0 for onednn
    double input_scale,
    int64_t input_zero_point,
    torch::List<int64_t> stride,
    torch::List<int64_t> padding,
    torch::List<int64_t> dilation,
    int64_t groups,
    std::optional<torch::List<int64_t>> input_shape) {
  int kSpatialDim = weight.ndimension() - 2;
  TORCH_CHECK(
      weight.ndimension() == kSpatialDim + 2,
      "Weights are expected to have ", kSpatialDim + 2, " dimensions");
  TORCH_CHECK(
      stride.size() == (decltype(stride.size()))kSpatialDim,
      "stride should contain ", kSpatialDim, " elements for ",
      kSpatialDim, "D convolution.");
  TORCH_CHECK(
      padding.size() == (decltype(padding.size()))kSpatialDim,
      "Specify front/top/left padding only. "
      "end/bottom/right padding assumed to be equal to front/top/left");
  TORCH_CHECK(
      dilation.size() == (decltype(dilation.size()))kSpatialDim,
      "dilation should contain ", kSpatialDim, " elements for ",
      kSpatialDim, "D convolution.");

  bool is_1d = (1 == kSpatialDim);
  auto x_dims = input_shape.has_value() ? input_shape.value().vec() : ideep::dims();
  if (is_1d) {
    if (input_shape.has_value()) {
      // N, C, L -> N, C, 1, L
      x_dims.insert(x_dims.begin() + 2, 1);
    }
    if (weight.dim() == 3) {
      weight = weight.unsqueeze(quant_utils::kConv1dSqueezeDim + 2);
    }
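    // The weight gets the same treatment as the input shape above: a dummy
    // spatial dim is inserted so the 1D case can be packed as a 2D conv, and
    // the conv args below receive a neutral placeholder for that dim
    // (stride 1, padding 0, dilation 1).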
    stride = quant_utils::MakeArgForConv1d(stride, 1);
    padding = quant_utils::MakeArgForConv1d(padding, 0);
    dilation = quant_utils::MakeArgForConv1d(dilation, 1);
    kSpatialDim += 1;
  }
  auto w_dims = weight.sizes().vec();
  auto strides = stride.vec();
  auto padding_l = padding.vec();
  auto padding_r = padding.vec();
  auto dilates = dilation.vec();
  auto op_attr = ideep::attr_t();

  ideep::scale_t weights_scales(weight_scales.numel());

  if (weight_scales.ndimension() == 0) {
    // Weight is quantized per tensor; weight_scales is then a scalar Tensor
    TORCH_CHECK(
        weight_scales.numel() == 1,
        "Weight is quant per tensor, weight scale expects 1 element but got ", weight_scales.numel(), " elements.");
#if IDEEP_PREREQ(3, 1, 0, 1)
    weights_scales[0] = weight_scales.item().toDouble();
#elif IDEEP_PREREQ(3, 1, 0, 0)
    weights_scales[0] = 1.0 / weight_scales.item().toDouble(); // Scales of ONEDNN and PyTorch are reciprocal
#else
    TORCH_CHECK(false, "Unexpected IDeep version to do qconv weight prepack.");
#endif
  } else {
    // Weight is quantized per channel
    for (int i = 0; i < weight_scales.numel(); ++i) {
#if IDEEP_PREREQ(3, 1, 0, 1)
      weights_scales[i] = weight_scales[i].item().toDouble();
#elif IDEEP_PREREQ(3, 1, 0, 0)
      weights_scales[i] = 1.0 / weight_scales[i].item().toDouble();
#else
      TORCH_CHECK(false, "Unexpected IDeep version to do qconv weight prepack.");
#endif
    }
  }

  if (input_scale != 1.0f) {
    op_attr.set_scales_mask(DNNL_ARG_SRC, /* src_scales_mask= */0);
  }
  if (input_zero_point != 0) {
    op_attr.set_zero_points_mask(DNNL_ARG_SRC, /* src_zero_points_mask= */0);
  }

  at::Tensor weight_copy;
  ideep::tensor::desc w_desc;
  ideep::dims dims_iohw, dims_giohw;
  ideep::tag w_tag = ideep::tag::any;
  const bool with_groups = groups > 1;
  w_desc = ideep::convolution_forward::expected_weights_desc(
      w_dims, dnnl::memory::data_type::s8,
      strides, padding_l, padding_r, dilates, groups,
      dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference,
      dnnl::memory::data_type::u8, x_dims, op_attr, /*is_channels_last=*/true);
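  // Passing the real input shape (when the caller provides one) lets oneDNN
  // query the weight layout it would pick for that exact problem size;
  // otherwise an empty shape yields a generic choice.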

  // Note: a Conv1d weight has already been unsqueezed into Conv2d form in the
  // step above.
  weight_copy = weight.clone(c10::MemoryFormat::Contiguous);

  if (with_groups) {
    w_tag = kSpatialDim == 2 ? ideep::tag::goihw : ideep::tag::goidhw;
  } else {
    w_tag = kSpatialDim == 2 ? ideep::tag::oihw : ideep::tag::oidhw;
  }
  ideep::dims wei_dims = with_groups ? ideep::utils::group_dims(w_desc.get_dims(), groups)
                                     : w_desc.get_dims();
  ideep::tensor wgt = ideep::tensor(
      ideep::tensor::desc({wei_dims, dnnl::memory::data_type::s8, w_tag}, groups),
      weight_copy.data_ptr());

  wgt.set_scale(weights_scales); // Scales are needed for feed_from().

  ideep::tensor exp_wgt;
  exp_wgt.init(w_desc);
  exp_wgt.set_scale(weights_scales); // Also for feed_from()
  exp_wgt.feed_from(wgt, /*transposed*/false); // expect wgt to be in [OC IC KH KW] format

  auto packed_weight = at::native::new_with_itensor_mkldnn(
      std::move(exp_wgt),
      c10::optTypeMetaToScalarType(weight_copy.options().dtype_opt()),
      weight_copy.options().device_opt());

  return packed_weight;
}

#endif // #if AT_MKLDNN_ENABLED()

namespace at {
namespace native {
namespace {

template <int kSpatialDim = 2>
class QConvPackWeightInt8 final {
 public:
  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> run_conv(
      Tensor weight,
      std::optional<Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> dilation,
      int64_t groups) {
    torch::List<int64_t> output_padding;
    output_padding.reserve(kSpatialDim);
    for (C10_UNUSED const auto idx : c10::irange(kSpatialDim)) {
      output_padding.push_back((int64_t)0);
    }
    return _run(weight, bias, stride, padding, output_padding, dilation, groups,
                /*transpose=*/false);
  }

  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> run_deconv(
      Tensor weight,
      std::optional<Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups) {
    return _run(weight, bias, stride, padding, output_padding, dilation, groups,
                /*transpose=*/true);
  }

 private:
  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> _run(
      Tensor weight,
      std::optional<Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose) {
    auto& ctx = at::globalContext();
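    // For the X86 qengine, prefer the ONEDNN packing path when
    // onednn_utils::should_use_onednn_quant() approves this combination of
    // weight, transpose, groups and output_padding; otherwise fall back to
    // the FBGEMM packer.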
#ifdef USE_FBGEMM
    if (ctx.qEngine() == at::QEngine::X86) {
#if AT_MKLDNN_ENABLED()
      bool use_onednn = onednn_utils::should_use_onednn_quant(
          weight, transpose, groups, output_padding);
      if (use_onednn) {
        return PackedConvWeightsOnednn<kSpatialDim>::prepack(
            weight, bias, stride, padding, output_padding, dilation, groups, transpose);
      }
#endif
      return PackedConvWeight<kSpatialDim>::prepack(
          weight, bias, stride, padding, output_padding, dilation, groups, transpose);
    } // x86
#endif // USE_FBGEMM

#ifdef USE_FBGEMM
    if (ctx.qEngine() == at::QEngine::FBGEMM) {
      return PackedConvWeight<kSpatialDim>::prepack(
          weight, bias, stride, padding, output_padding, dilation, groups,
          transpose);
    }
#endif

#ifdef USE_PYTORCH_QNNPACK
    if (ctx.qEngine() == at::QEngine::QNNPACK) {
      return PackedConvWeightsQnnp<kSpatialDim>::prepack(
          weight, bias, stride, padding, output_padding, dilation, groups,
          transpose);
    }
#endif

#if AT_MKLDNN_ENABLED()
    if (ctx.qEngine() == at::QEngine::ONEDNN) {
      return PackedConvWeightsOnednn<kSpatialDim>::prepack(
          weight, bias, stride, padding, output_padding, dilation, groups,
          transpose);
    }
#endif

    TORCH_CHECK(
        false,
        "Didn't find engine for operation quantized::conv2d_prepack ",
        toString(ctx.qEngine()));
  }
};

class QConv1dPackWeightInt8 final {
 public:
  static c10::intrusive_ptr<ConvPackedParamsBase<2>> run_conv(
      Tensor weight,
      std::optional<Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> dilation,
      int64_t groups) {
    const torch::List<int64_t> output_padding({0});
    return _run(std::move(weight), std::move(bias), stride, padding, output_padding, dilation, groups,
                /*transpose=*/false);
  }

  static c10::intrusive_ptr<ConvPackedParamsBase<2>> run_deconv(
      Tensor weight,
      std::optional<Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups) {
    return _run(std::move(weight), std::move(bias), stride, padding, output_padding, dilation, groups,
                /*transpose=*/true);
  }

 private:
  static c10::intrusive_ptr<ConvPackedParamsBase<2>> _run(
      Tensor weight,
      std::optional<Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose) {
    auto& ctx = at::globalContext();
    if (weight.dim() == 3) {
      weight = weight.unsqueeze(quant_utils::kConv1dSqueezeDim + 2);
    }
    stride = quant_utils::MakeArgForConv1d(stride, 1);
    padding = quant_utils::MakeArgForConv1d(padding, 0);
    output_padding = quant_utils::MakeArgForConv1d(output_padding, 0);
    dilation = quant_utils::MakeArgForConv1d(dilation, 1);
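    // From here on the 1D case is treated as a 2D conv with a singleton
    // spatial dim, so the kSpatialDim == 2 packers below can be reused.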

#ifdef USE_FBGEMM
    if (ctx.qEngine() == at::QEngine::X86) {
#if AT_MKLDNN_ENABLED()
      bool use_onednn = onednn_utils::should_use_onednn_quant(
          weight, transpose, groups, output_padding);
      if (use_onednn) {
        return PackedConvWeightsOnednn<2>::prepack(
            weight, bias, stride, padding, output_padding, dilation, groups,
            transpose);
      }
#endif
      return PackedConvWeight<2>::prepack(
          std::move(weight), std::move(bias), stride, padding, output_padding, dilation, groups,
          transpose);
    } // x86
#endif

#ifdef USE_FBGEMM
    if (ctx.qEngine() == at::QEngine::FBGEMM) {
      return PackedConvWeight<2>::prepack(
          std::move(weight), std::move(bias), stride, padding, output_padding, dilation, groups,
          transpose);
    }
#endif

#ifdef USE_PYTORCH_QNNPACK
    if (ctx.qEngine() == at::QEngine::QNNPACK) {
      return PackedConvWeightsQnnp<2>::prepack(
          std::move(weight), std::move(bias), stride, padding, output_padding, dilation, groups,
          transpose);
    }
#endif

#if AT_MKLDNN_ENABLED()
    if (ctx.qEngine() == at::QEngine::ONEDNN) {
      return PackedConvWeightsOnednn<2>::prepack(
          weight, bias, stride, padding, output_padding, dilation, groups,
          transpose);
    }
#endif

    TORCH_CHECK(
        false,
        "Didn't find engine for operation quantized::conv1d_prepack ",
        toString(ctx.qEngine()));
  }
};

class QConvPrepackOneDNN final {
 public:
  static at::Tensor run_conv(
    at::Tensor weight, // from CPU backend instead of QuantizedCPU
    at::Tensor weight_scales, // Weight zero points must be 0s for onednn
    double input_scale,
    int64_t input_zero_point,
    torch::List<int64_t> stride,
    torch::List<int64_t> padding,
    torch::List<int64_t> dilation,
    int64_t groups,
    std::optional<torch::List<int64_t>> input_shape) {
#if AT_MKLDNN_ENABLED()
    return _qconv_prepack_onednn(
        weight, weight_scales, input_scale, input_zero_point,
        stride, padding, dilation, groups, input_shape);
#else
    TORCH_CHECK(false, "Unimplemented as onednn is not available.");
#endif
  }
};

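// These registrations expose the prepack routines above as quantized ops;
// quantized conv modules typically call the matching *_prepack op once (at
// module construction / weight load) and feed the returned packed params to
// the corresponding quantized conv ops at run time.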
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  // Conv
  // conv_prepack is deprecated, please use conv2d_prepack for 2D conv.
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv_prepack"), TORCH_FN(QConvPackWeightInt8<2>::run_conv));
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d_prepack"), TORCH_FN(QConv1dPackWeightInt8::run_conv));
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_prepack"), TORCH_FN(QConvPackWeightInt8<2>::run_conv));
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv3d_prepack"), TORCH_FN(QConvPackWeightInt8<3>::run_conv));
  // ConvTranspose
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose1d_prepack"), TORCH_FN(QConv1dPackWeightInt8::run_deconv));
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_prepack"), TORCH_FN(QConvPackWeightInt8<2>::run_deconv));
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_prepack"), TORCH_FN(QConvPackWeightInt8<3>::run_deconv));
}

TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) {
  // Conv
  m.impl(TORCH_SELECTIVE_NAME("_quantized::conv2d_prepack"), TORCH_FN(QConvPackWeightInt8<2>::run_conv));
  m.impl(TORCH_SELECTIVE_NAME("_quantized::conv3d_prepack"), TORCH_FN(QConvPackWeightInt8<3>::run_conv));
  // ConvTranspose
  m.impl(TORCH_SELECTIVE_NAME("_quantized::conv_transpose1d_prepack"), TORCH_FN(QConv1dPackWeightInt8::run_deconv));
  m.impl(TORCH_SELECTIVE_NAME("_quantized::conv_transpose2d_prepack"), TORCH_FN(QConvPackWeightInt8<2>::run_deconv));
  m.impl(TORCH_SELECTIVE_NAME("_quantized::conv_transpose3d_prepack"), TORCH_FN(QConvPackWeightInt8<3>::run_deconv));
}

TORCH_LIBRARY_IMPL(onednn, CPU, m) {
  // New OP definition for Quantization in PyTorch 2.0 Export
  // Conv Prepack
  m.impl(TORCH_SELECTIVE_NAME("onednn::qconv_prepack"), TORCH_FN(QConvPrepackOneDNN::run_conv));
}

} // namespace
} // namespace native
} // namespace at