xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qconv.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <algorithm>
3 #include <cmath>
4 #include <string>
5 #include <vector>
6 
7 #include <ATen/core/Tensor.h>
8 #include <ATen/core/List.h>
9 #include <ATen/Context.h>
10 #include <ATen/Parallel.h>
11 #include <ATen/TensorOperators.h>
12 #include <ATen/SmallVector.h>
13 #include <ATen/native/quantized/PackedParams.h>
14 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
15 #include <ATen/native/quantized/cpu/QnnpackUtils.h>
16 #include <ATen/native/quantized/cpu/XnnpackUtils.h>
17 #include <ATen/native/quantized/cpu/OnednnUtils.h>
18 #include <ATen/native/quantized/ConvUtils.h>
19 #include <ATen/native/quantized/cpu/QuantUtils.h>
20 #include <caffe2/utils/threadpool/pthreadpool-cpp.h>
21 #include <torch/library.h>
22 #include <ATen/quantized/Quantizer.h>
23 #include <ATen/native/mkldnn/MKLDNNCommon.h>
24 
25 #ifndef AT_PER_OPERATOR_HEADERS
26 #include <ATen/Functions.h>
27 #include <ATen/NativeFunctions.h>
28 #else
29 #include <ATen/ops/_empty_affine_quantized.h>
30 #include <ATen/ops/_empty_affine_quantized_native.h>
31 #include <ATen/ops/_empty_per_channel_affine_quantized_native.h>
32 #include <ATen/ops/empty.h>
33 #include <ATen/ops/quantize_per_channel_native.h>
34 #include <ATen/ops/quantize_per_tensor_native.h>
35 #include <ATen/ops/zeros.h>
36 #endif
37 
38 #include <c10/util/irange.h>
39 
40 namespace {
41 // A sanity-check upper bound on output dimension sizes.
42 constexpr int64_t kReasonableMaxDim = 1000000;
43 } // namespace
44 
45 template <int kSpatialDim = 2>
46 bool ConvDimChecks(
47     int64_t act_dims,
48     int64_t stride_dims,
49     int64_t padding_dims,
50     int64_t output_padding_dims,
51     int64_t dilation_dims,
52     std::string func_name,
53     bool transpose = false) {
54   TORCH_CHECK(
55       act_dims == kSpatialDim + 2,
56       func_name,
57       kSpatialDim,
58       "d(): Expected activation tensor to have ",
59       kSpatialDim + 2,
60       " dimensions, got ",
61       act_dims);
62   TORCH_CHECK(
63       stride_dims == kSpatialDim,
64       func_name,
65       kSpatialDim,
66       "d(): Expected stride tensor to have ",
67       kSpatialDim,
68       " dimensions, got ",
69       stride_dims);
70   TORCH_CHECK(
71       padding_dims == kSpatialDim,
72       func_name,
73       kSpatialDim,
74       "d(): Expected padding tensor to have ",
75       kSpatialDim,
76       " dimensions, got ",
77       padding_dims);
78   TORCH_CHECK(
79       !transpose || (output_padding_dims == kSpatialDim),
80       func_name,
81       kSpatialDim,
82       "d(): Expected output padding tensor to have ",
83       kSpatialDim,
84       " dimensions, got ",
85       output_padding_dims);
86   TORCH_CHECK(
87       dilation_dims == kSpatialDim,
88       func_name,
89       kSpatialDim,
90       "d(): Expected dilation tensor to have ",
91       kSpatialDim,
92       " dimensions, got ",
93       dilation_dims);
94   return true;
95 }
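// For example, with kSpatialDim == 2 the activation must be a 4-d (NCHW)
// tensor, and stride/padding/dilation (plus output_padding when transpose is
// true) must each hold 2 elements.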
96 
97 inline int64_t compute_deconv_shape(int64_t input,
98                                     int64_t kernel,
99                                     int64_t stride,
100                                     int64_t input_padding,
101                                     int64_t output_padding,
102                                     int64_t dilation) {
103   int64_t out = (input - 1) * stride - 2 * input_padding
104                 + dilation * (kernel - 1) + output_padding + 1;
105   return out;
106 }
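// For example, with illustrative values input = 4, kernel = 3, stride = 2,
// input_padding = 1, output_padding = 1 and dilation = 1, the result is
// (4 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 1 + 1 = 8.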
107 
108 template <int64_t kSpatialDim>
109 at::SmallVector<int64_t, kSpatialDim + 2> MakeDeConvOutputShape(
110     int64_t N, int64_t M,
111     const std::vector<int64_t>& input_shape,
112     const std::vector<int64_t>& kernel,
113     const torch::List<int64_t>& stride,
114     const torch::List<int64_t>& input_padding,
115     const torch::List<int64_t>& output_padding,
116     const torch::List<int64_t>& dilation) {
117   at::SmallVector<int64_t, kSpatialDim + 2> output_shape;
118   output_shape.resize(kSpatialDim + 2);
119   output_shape[0] = N;  // Batch size
120   output_shape[1] = M;  // Output channels
121   for (const auto idx : c10::irange(kSpatialDim)) {
122     output_shape[idx + 2] = compute_deconv_shape(input_shape[idx],
123                                                  kernel[idx],
124                                                  stride[idx],
125                                                  input_padding[idx],
126                                                  output_padding[idx],
127                                                  dilation[idx]);
128     TORCH_CHECK(output_shape[idx + 2] > 0,
129                 "Output dimension must be positive for ", idx, " axis;"
130                 " kernel: ", kernel[idx],
131                 ", stride: ", stride[idx],
132                 ", input padding: ", input_padding[idx],
133                 ", output padding: ", output_padding[idx],
134                 ", dilation: ", dilation[idx])
135     TORCH_CHECK(output_shape[idx + 2] < kReasonableMaxDim,
136                 "Output dimension is beyond reasonable maximum for ", idx,
137                 " axis;"
138                 " kernel: ", kernel[idx],
139                 ", stride: ", stride[idx],
140                 ", input padding: ", input_padding[idx],
141                 ", output padding: ", output_padding[idx],
142                 ", dilation: ", dilation[idx]);
143   }
144   return output_shape;
145 }
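// E.g. (illustrative values) MakeDeConvOutputShape<2>(/*N=*/1, /*M=*/8,
// /*input_shape=*/{4, 4}, /*kernel=*/{3, 3}, /*stride=*/{2, 2},
// /*input_padding=*/{1, 1}, /*output_padding=*/{1, 1}, /*dilation=*/{1, 1})
// returns {1, 8, 8, 8}.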
146 
147 #ifdef USE_FBGEMM
148 
149 template <int kSpatialDim = 2>
150 at::SmallVector<int64_t, kSpatialDim + 2> MakeConvOutputShape(
151     int N,
152     int M,
153     const std::array<int, kSpatialDim>& output_image_shape);
154 
155 template <>
156 at::SmallVector<int64_t, 4> MakeConvOutputShape<2>(
157     int N,
158     int M,
159     const std::array<int, 2>& output_image_shape) {
160   return {N, M, output_image_shape[0], output_image_shape[1]};
161 }
162 
163 template <>
164 at::SmallVector<int64_t, 5> MakeConvOutputShape<3>(
165     int N,
166     int M,
167     const std::array<int, 3>& output_image_shape) {
168   return {N,
169           M,
170           output_image_shape[0],
171           output_image_shape[1],
172           output_image_shape[2]};
173 }
174 
175 #endif // USE_FBGEMM
176 
177 #ifdef USE_PYTORCH_QNNPACK
178 
179 template <size_t kSpatialDim>
180 std::array<int64_t, kSpatialDim> MakeInputShape(
181     int64_t D,
182     int64_t H,
183     int64_t W);
184 
185 template <>
186 std::array<int64_t, 2> MakeInputShape(int64_t /*D*/, int64_t H, int64_t W) {
187   return {H, W};
188 }
189 template <>
190 std::array<int64_t, 3> MakeInputShape(int64_t D, int64_t H, int64_t W) {
191   return {D, H, W};
192 }
193 
194 #endif // USE_PYTORCH_QNNPACK
195 
196 #ifdef USE_FBGEMM
197 template <int kSpatialDim>
198 const float* PackedConvWeight<kSpatialDim>::GetBiasData(at::Tensor* bias_ptr) {
199   const float* bias_data = nullptr;
200   if (bias.has_value()) {
201     *bias_ptr = bias.value();
202     TORCH_CHECK(
203         bias_ptr->dtype() == at::kFloat,
204         "[QConv3D] The 'bias' tensor must have 'torch.float' dtype");
205     *bias_ptr = bias_ptr->contiguous();
206     TORCH_CHECK(bias_ptr->dim() == 1, "bias should be a vector (1D Tensor)");
207     const int M = w->outputChannels();
208     TORCH_CHECK(bias_ptr->size(0) == M, "bias should have ", M, " elements.");
209     bias_data = bias_ptr->data_ptr<float>();
210   }
211   return bias_data;
212 }
213 
214 template <int kSpatialDim>
215 void PackedConvWeight<kSpatialDim>::GetQuantizationParams(
216     float act_scale,
217     float out_scale,
218     std::vector<float>* output_multiplier_float,
219     std::vector<float>* act_times_w_scale) {
220   if (q_scheme == c10::kPerTensorAffine) {
221     *act_times_w_scale = {(act_scale * w_scale[0])};
222     *output_multiplier_float = {act_times_w_scale->front() / out_scale};
223   } else if (q_scheme == c10::kPerChannelAffine) {
224     const int M = w->outputChannels();
225     output_multiplier_float->resize(M);
226     act_times_w_scale->resize(M);
227     for (const auto i : c10::irange(M)) {
228       act_times_w_scale->at(i) = (act_scale * w_scale[i]);
229       output_multiplier_float->at(i) = act_times_w_scale->at(i) / out_scale;
230     }
231   } else {
232     TORCH_CHECK(false, "[QConv", kSpatialDim, "D] Unknown quantization scheme");
233   }
234 }
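// With illustrative per-tensor values act_scale = 0.1, w_scale[0] = 0.2 and
// out_scale = 0.05, act_times_w_scale becomes 0.02 and output_multiplier_float
// becomes 0.02 / 0.05 = 0.4, the factor used to requantize the int32
// accumulator into the output's quantized domain.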
235 
236 template <int kSpatialDim>
237 at::Tensor PackedConvWeight<kSpatialDim>::apply(
238     const at::Tensor& input,
239     double output_scale,
240     int64_t output_zero_point) {
241   return apply_impl<false>(input, output_scale, output_zero_point);
242 }
243 
244 template <int kSpatialDim>
245 at::Tensor PackedConvWeight<kSpatialDim>::apply_relu(
246     const at::Tensor& input,
247     double output_scale,
248     int64_t output_zero_point) {
249   return apply_impl<true>(input, output_scale, output_zero_point);
250 }
251 
252 template <int kSpatialDim>
253 template <bool kReluFused>
254 at::Tensor PackedConvWeight<kSpatialDim>::apply_impl(
255     const at::Tensor& act,
256     double output_scale,
257     int64_t output_zero_point) {
258   // Quantized kernels are all written with NHWC (channels last) layout in
259   // mind. Ideally, we'd be compatible with conv2d behavior and preserve the
260   // input's layout as-is (doing necessary upconversions).
261   //
262   // However, to be more robust, for now we just force output layout to always
263   // be NHWC (channels last), thus opportunistically improving perf.
264   //
265   // This might change when full memory format support lands
266   // See https://github.com/pytorch/pytorch/issues/23403
267   const std::string func_name = transpose() ? "quantized::conv_transpose"
268                                             : "quantized::conv";
269   TORCH_CHECK(
270       fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
271   TORCH_CHECK(act.scalar_type() == c10::kQUInt8,
272                 func_name,
273                 "(FBGEMM): Expected activation data type ",
274                 toString(c10::kQUInt8),
275                 " but got ",
276                 toString(act.scalar_type()));
277 
278   ConvDimChecks<kSpatialDim>(
279       act.ndimension(), stride().size(), padding().size(),
280       output_padding().size(), dilation().size(), func_name, transpose());
281 
282   const int N = act.size(0);
283   const int C = act.size(1);
284   const int D = kSpatialDim == 2 ? 1 : act.size(2);
285   const int H = act.size(kSpatialDim);
286   const int W = act.size(kSpatialDim + 1);
287 
288   const at::Tensor act_ndhwc = kSpatialDim == 2
289       ? act.contiguous(c10::MemoryFormat::ChannelsLast)
290       : at::native::fbgemm_utils::ConvertToChannelsLast3dTensor(act);
291   const uint8_t* act_data =
292       reinterpret_cast<uint8_t*>(act_ndhwc.data_ptr<c10::quint8>());
293   auto* pack_w = w.get();
294 
295   const int M = pack_w->outputChannels();
296   const int kernel_d = kSpatialDim == 2 ? 1 : kernel[0];
297   const int kernel_h = kernel[kSpatialDim - 2];
298   const int kernel_w = kernel[kSpatialDim - 1];
299   const int pad_d = kSpatialDim == 2 ? 0 : padding_[0];
300   const int pad_h = padding_[kSpatialDim - 2];
301   const int pad_w = padding_[kSpatialDim - 1];
302   const int stride_d = kSpatialDim == 2 ? 1 : stride_[0];
303   const int stride_h = stride_[kSpatialDim - 2];
304   const int stride_w = stride_[kSpatialDim - 1];
305   const int dilation_d = kSpatialDim == 2 ? 1 : dilation_[0];
306   const int dilation_h = dilation_[kSpatialDim - 2];
307   const int dilation_w = dilation_[kSpatialDim - 1];
308   const int output_padding_d = kSpatialDim == 2 ? 0 : output_padding_[0];
309   const int output_padding_h = output_padding_[kSpatialDim - 2];
310   const int output_padding_w = output_padding_[kSpatialDim - 1];
311 
312   if (kSpatialDim == 2) {
313     TORCH_CHECK(
314         C == pack_w->inputChannels(),
315         "[QConv2D] Given groups=",
316         groups_,
317         ", weight of size ",
318         M,
319         ", ",
320         kernel_h,
321         ", ",
322         kernel_w,
323         ", ",
324         pack_w->inputChannels(),
325         ", expected input (NCHW) ",
326         N,
327         ", ",
328         C,
329         ", ",
330         H,
331         ", ",
332         W,
333         " to have ",
334         pack_w->inputChannels(),
335         " channels, but got ",
336         C,
337         " channels instead");
338   } else {
339     TORCH_CHECK(
340         C == pack_w->inputChannels(),
341         "[QConv3D] Given groups=",
342         groups_,
343         ", weight of size ",
344         M,
345         ", ",
346         kernel_d,
347         ", ",
348         kernel_h,
349         ", ",
350         kernel_w,
351         ", ",
352         pack_w->inputChannels(),
353         ", expected input (NCDHW) ",
354         N,
355         ", ",
356         C,
357         ", ",
358         D,
359         ", ",
360         H,
361         ", ",
362         W,
363         " to have ",
364         pack_w->inputChannels(),
365         " channels, but got ",
366         C,
367         " channels instead");
368   }
369 
370   fbgemm::conv_param_t<kSpatialDim> conv_p =
371       at::native::fbgemm_utils::MakeFbgemmConvParam<kSpatialDim>(
372           N, // Batch size
373           C, // Number of input channels
374           M, // Number of output channels
375           kSpatialDim == 2 ? std::vector<int>{H, W} : std::vector<int>{D, H, W},
376           groups_,
377           kSpatialDim == 2 ? std::vector<int>{kernel_h, kernel_w}
378                            : std::vector<int>{kernel_d, kernel_h, kernel_w},
379           kSpatialDim == 2 ? std::vector<int>{stride_h, stride_w}
380                            : std::vector<int>{stride_d, stride_h, stride_w},
381           kSpatialDim == 2 ? std::vector<int>{pad_h, pad_w}
382                            : std::vector<int>{pad_d, pad_h, pad_w},
383           kSpatialDim == 2
384               ? std::vector<int>{dilation_h, dilation_w}
385               : std::vector<int>{dilation_d, dilation_h, dilation_w},
386           kSpatialDim == 2
387               ? std::vector<int>{output_padding_h, output_padding_w}
388               : std::vector<int>{output_padding_d,
389                                  output_padding_h,
390                                  output_padding_w},
391           transpose());
392 
393   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
394   const float act_scale = act.q_scale();
395   const int32_t act_zero_point = act.q_zero_point();
396 
397   at::Tensor bias;
398   const float* bias_data = GetBiasData(&bias);
399 
400   TORCH_CHECK(
401       w_scale.size() == w_zp.size(),
402       "Weight scales and zero points vectors should have the same size.");
403   std::vector<float> output_multiplier_float;
404   std::vector<float> act_times_w_scale;
405   GetQuantizationParams(
406       act_scale, output_scale, &output_multiplier_float, &act_times_w_scale);
407 
408   at::SmallVector<int64_t, kSpatialDim + 2> output_shape;
409   if (transpose()) {
410     output_shape = MakeDeConvOutputShape<kSpatialDim>(
411         N,
412         M,
413         kSpatialDim == 2 ? std::vector<int64_t>{H, W} : std::vector<int64_t>{D, H, W},
414         kernel,
415         stride(),
416         padding(),
417         output_padding(),
418         dilation());
419 
420     // If the direct convolution implementation is used, compute the col_offsets
421     // of the weight matrix at model initialization stage.
422     // We need to know the shape of the output matrix
423     // to compute col_offsets for direct convolution.
424     // Hence it cannot be called from inside the weight packing function
425     // like the other quantized conv implementations.
426     if (pack_w->getPackedWForDirectconv().get() &&
427         pack_w->getPackedWForDirectconv().get()->is_first_call()) {
428           pack_w->getPackedWForDirectconv().get()->col_offsets_with_zero_pt_s8acc32_DirectConvT(
429               conv_p,
430               w_zp.data(),
431               col_offsets,
432               M);
433     }
434   } else {
435     output_shape = MakeConvOutputShape<kSpatialDim>(N, M, conv_p.OUT_DIM);
436   }
437   if (N > 0) {
438     TORCH_CHECK(
439         std::all_of(
440             output_shape.begin(),
441             output_shape.end(),
442             [](int64_t i) { return i > 0; }),
443         "[QConv",
444         kSpatialDim,
445         "D] each dimension of output tensor should be greater than 0");
446   }
447   at::Tensor output = kSpatialDim == 2
448       ? at::_empty_affine_quantized(
449             output_shape,
450             device(c10::kCPU)
451                 .dtype(c10::kQUInt8)
452                 .memory_format(c10::MemoryFormat::ChannelsLast),
453             output_scale,
454             output_zero_point,
455             std::nullopt)
456       : at::native::fbgemm_utils::MakeEmptyAffineQuantizedChannelsLast3dTensor(
457             output_shape[0],
458             output_shape[1],
459             output_shape[2],
460             output_shape[3],
461             output_shape[4],
462             device(c10::kCPU).dtype(c10::kQUInt8),
463             output_scale,
464             output_zero_point);
465   at::Tensor buffer =
466       at::empty(output.sizes(), output.options().dtype(c10::kInt));
467   const int num_tasks = at::get_num_threads();
468   at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
469     fbgemm::DoNothing<> kNoOpObj{};
470     for (const auto task_id : c10::irange(begin, end)) {
471       if (q_scheme == c10::kPerTensorAffine) {
472         fbgemm::ReQuantizeOutput<
473             kReluFused,
474             fbgemm::QuantizationGranularity::TENSOR,
475             float>
476             output_proc_obj(
477                 kNoOpObj,
478                 output_multiplier_float.data(),
479                 output_zero_point,
480                 act_zero_point,
481                 w_zp.data(),
482                 nullptr, /* row offset buffer */
483                 col_offsets.data(),
484                 bias_data,
485                 M,
486                 groups_,
487                 act_times_w_scale.data());
488         fbgemm::fbgemmConv<decltype(output_proc_obj), kSpatialDim, int32_t>(
489             conv_p,
490             act_data,
491             *pack_w,
492             reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
493             buffer.data_ptr<int32_t>(),
494             output_proc_obj,
495             task_id /* thread_id*/,
496             num_tasks /* num_threads */);
497       } else if (q_scheme == c10::kPerChannelAffine) {
498         fbgemm::ReQuantizeOutput<
499             kReluFused,
500             fbgemm::QuantizationGranularity::OUT_CHANNEL,
501             float>
502             output_proc_obj(
503                 kNoOpObj,
504                 output_multiplier_float.data(),
505                 output_zero_point,
506                 act_zero_point,
507                 w_zp.data(),
508                 nullptr, /* row offset buffer */
509                 col_offsets.data(),
510                 bias_data,
511                 M,
512                 groups_,
513                 act_times_w_scale.data());
514 
515         fbgemm::fbgemmConv<decltype(output_proc_obj), kSpatialDim, int32_t>(
516             conv_p,
517             act_data,
518             *pack_w,
519             reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
520             buffer.data_ptr<int32_t>(),
521             output_proc_obj,
522             task_id /* thread_id*/,
523             num_tasks /* num_threads */);
524       }
525     }
526   });
527 
528   return output;
529 }
530 
531 template at::Tensor PackedConvWeight<2>::apply(
532     const at::Tensor& act,
533     double output_scale,
534     int64_t output_zero_point);
535 
536 template at::Tensor PackedConvWeight<2>::apply_relu(
537     const at::Tensor& act,
538     double output_scale,
539     int64_t output_zero_point);
540 
541 template at::Tensor PackedConvWeight<3>::apply(
542     const at::Tensor& act,
543     double output_scale,
544     int64_t output_zero_point);
545 
546 template at::Tensor PackedConvWeight<3>::apply_relu(
547     const at::Tensor& act,
548     double output_scale,
549     int64_t output_zero_point);
550 
551 template at::Tensor PackedConvWeight<2>::apply_impl<false>(
552     const at::Tensor& act,
553     double output_scale,
554     int64_t output_zero_point);
555 
556 template at::Tensor PackedConvWeight<3>::apply_impl<false>(
557   const at::Tensor& act,
558   double output_scale,
559   int64_t output_zero_point);
560 
561 #endif // USE_FBGEMM
562 
563 #ifdef USE_PYTORCH_QNNPACK
564 
565 #ifdef USE_XNNPACK
566 template <int kSpatialDim>
567 template <typename scalar_t, bool kReluFused>
568 at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_impl_xnnp(
569     const at::Tensor& act, double output_scale, int64_t output_zero_point) {
570   using underlying_t = typename scalar_t::underlying;
571 
572   std::lock_guard<std::mutex> lock(qnnp_mutex_);
573 
574   const std::string func_name = transpose()
575       ? "quantized::conv_transpose (xnnpack)"
576       : "quantized::conv (xnnpack)";
577   TORCH_CHECK(
578       kSpatialDim == 2,
579       func_name, ": xnnpack does not currently support 3d convolution.");
580 
581   /*
582    * NB:
583    * [de]conv_prepack prepares weights (values, scale, and zero_points) ahead of
584    * time during the prepack() call, assuming the activation will be uint8_t.
585    * But it may not always be the case. A solution may involve making the prepack
586    * routine aware of the input qdtype, but currently not all the pieces are in
587    * place to pass that model-level info to the prepack function. So, for now,
588    * here in this function we have to massage the weights if we learn that the
589    * input qdtype is not uint8_t. This involves copying and converting uint8_t to
590    * int8_t whenever necessary. Also, since XNNPACK, as of this writing, doesn't
591    * support per_channel weights for quint8_t, we add the following assert to
592    * make sure we don't run into that case. We also take shortcuts when
593    * processing weights, which means we will have to revisit and fix some of the
594    * weight massaging logic when we enable the missing feature in XNNPACK.
595    *
596    * The table below summarizes how the weights are handled:
597    *
598    * .-------------------------------------------------------------------------.
599    * | input_qdtype |              uint8_t            |            int8_t      |
600    * | per_channel  |       yes       |       no      |      yes     |    no   |
601    * |-------------------------------------------------------------------------|
602    * | zero_points  | at::zeros()*    | orig_zp + 128 | at::zeros()**| orig_zp |
603    * | scale        |            dtype = float, no changes needed              |
604    * | values       |        always processed before passing to XNNPACK        |
605    * .-------------------------------------------------------------------------.
606    *
607    * Notes: * - zero_points for uint8_t + per_channel: no support in xnnpack, need
608    * to fix when support is added. ** - zero_points for int8_t: symmetric
609    * quantization means XNNPACK will ignore kernel zero point(s).
610    */
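  // For instance, converting an int8 weight with zero point 0 to uint8 shifts
  // both the values and the zero point by 128 (e.g. -5 -> 123, zp 0 -> 128),
  // so the dequantized value scale * (q - zp) is unchanged.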
611 
612   if constexpr (std::is_same_v<underlying_t, c10::quint8>) {
613     TORCH_CHECK(!per_channel(),
614       func_name, ": xnnpack does not currently have per_channel support with activation dtype of c10::quint8."
615     );
616   }
617 
618   // More checks
619   ConvDimChecks<kSpatialDim>(
620       act.ndimension(),
621       stride().size(),
622       padding().size(),
623       output_padding().size(),
624       dilation().size(),
625       func_name,
626       transpose());
627 
628   const int64_t N = act.size(0);
629   const int64_t H = act.size(2);
630   const int64_t W = act.size(3);
631   const int64_t D = 1;
632   const int64_t M = bias.size(0);
633 
634   const auto act_nhwc = act.contiguous(c10::MemoryFormat::ChannelsLast);
635   const auto act_input_scale = act_nhwc.q_scale();
636 
637   auto status = xnn_status_invalid_state;
638 
639   // Create an operator iff necessary
640   if (!xnnp_convolution_op ||
641       (!input_scale.has_value() || input_scale.value() != act_input_scale)) {
642     xnn_operator_t xnnp_op = nullptr;
643 
644     // Update the input scale so we may cache the op
645     input_scale = act_input_scale;
646 
647     // create an empty tensor for packing the weights
648     const at::Tensor weight_contig =
649         orig_weight.contiguous(c10::MemoryFormat::ChannelsLast);
650     const float* w_scales_data = w_scales.const_data_ptr<float>();
651     underlying_t w_zp = 0;
652     at::Tensor weight_tensor;
653 
654     if (!per_channel()) {
655       w_zp = static_cast<underlying_t>(
656           weight_contig.q_zero_point() +
657           (std::is_same<underlying_t, uint8_t>::value ? 128 : 0));
658 
659       weight_tensor = at::native::empty_affine_quantized(
660           weight_contig.sizes(),
661           c10::CppTypeToScalarType<scalar_t>::value,
662           std::nullopt /* layout */,
663           c10::kCPU,
664           std::nullopt /* pin_memory */,
665           w_scales_data[0],
666           w_zp,
667           c10::MemoryFormat::ChannelsLast);
668     } else { /* per_channel */
669       weight_tensor = at::native::empty_per_channel_affine_quantized(
670           weight_contig.sizes(),
671           w_scales,
672           at::zeros(w_scales.sizes(), at::kInt), /* see comment above about w_zp */
673           weight_contig.q_per_channel_axis(),
674           c10::CppTypeToScalarType<scalar_t>::value,
675           std::nullopt /* layout */,
676           c10::kCPU,
677           std::nullopt /* pin_memory */,
678           c10::MemoryFormat::ChannelsLast);
679     }
680 
681     // copy from the original weight and take care of dtype change if necessary
682     at::native::xnnp_utils::q8_copy_int8_weight_and_add_offset<scalar_t>(
683         weight_contig, weight_tensor);
684     const at::Tensor xnnp_weight =
685         at::native::xnnp_utils::convert_conv_weights_to_channel_last_tensor<
686             kSpatialDim>(weight_tensor, groups(), transpose());
687 
688     auto output_min = kReluFused
689         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
690         ? activationLimits<underlying_t>(output_scale, output_zero_point, Activation::RELU).first
691         : std::numeric_limits<underlying_t>::min();
692     auto output_max = kReluFused
693         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
694         ? activationLimits<underlying_t>(output_scale, output_zero_point, Activation::RELU).second
695         : std::numeric_limits<underlying_t>::max();
696 
697 
698     // Original bias was float, so we requantize it here.
699     at::Tensor qbias = quant_utils::QuantizeBias(per_channel(), bias, weight_contig, act_input_scale);
700 
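    // The bias is quantized to int32 with scale act_input_scale * w_scale
    // (per output channel when applicable) and zero point 0, so it can be
    // added directly to the int32 accumulator.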
701     status = at::native::xnnp_utils::xnnp_create_convolution2d_nhwc(
702         padding()[0],
703         padding()[1],
704         padding()[0],
705         padding()[1],
706         kernel_[0],
707         kernel_[1],
708         stride()[0],
709         stride()[1],
710         dilation()[0],
711         dilation()[1],
712         groups(),
713         !transpose() ? orig_weight.size(1) : orig_weight.size(0) / groups(),
714         !transpose() ? orig_weight.size(0) / groups() : orig_weight.size(1),
715         !transpose() ? orig_weight.size(1) * groups() : orig_weight.size(0),
716         !transpose() ? orig_weight.size(0) : orig_weight.size(1) * groups(),
717         act_nhwc.q_zero_point(),
718         act_input_scale,
719         w_zp, /* will be ignored for Q[SC]8, see comment
720                 above about w_zp*/
721         w_scales_data,
722         reinterpret_cast<const underlying_t*>(
723             xnnp_weight.template data_ptr<scalar_t>()),
724         reinterpret_cast<int32_t*>(qbias.template data_ptr<c10::qint32>()),
725         output_zero_point,
726         output_scale,
727         output_min,
728         output_max,
729         0,
730         &xnnp_op,
731         per_channel(),
732         transpose());
733 
734     xnnp_convolution_op = xnnpack_operator(xnnp_op);
735     TORCH_CHECK(
736         status == xnn_status_success,
737         func_name,
738         ": xnn create operator failed(",
739         status,
740         ")");
741   }
742 
743   at::SmallVector<int64_t, kSpatialDim + 2> output_shape;
744   const auto input_shape = MakeInputShape<kSpatialDim>(D, H, W);
745   if (transpose()) {
746     output_shape = MakeDeConvOutputShape<kSpatialDim>(
747         N, M, {H, W}, kernel_, stride(), padding(), output_padding(), dilation());
748   } else {
749     output_shape = at::native::quantized::MakeConvOutputShape<kSpatialDim>(
750         N, M, input_shape, kernel_, stride(), padding(), dilation());
751   }
752 
753   if (act_nhwc.numel() > 0) {
754     TORCH_CHECK(
755         std::all_of(
756             output_shape.begin(),
757             output_shape.end(),
758             [](int64_t i) { return i > 0; }),
759         func_name, ": ", kSpatialDim, "d (xnnpack): each dimension of output tensor should be greater than 0.")
760   }
761 
762   // Allocate output Tensor and a buffer for XNNPACK to use
763   at::Tensor output = at::native::empty_affine_quantized(
764       output_shape,
765       c10::CppTypeToScalarType<scalar_t>::value,
766       std::nullopt /* layout */,
767       c10::kCPU,
768       std::nullopt /* pin_memory */,
769       output_scale,
770       output_zero_point,
771       c10::MemoryFormat::ChannelsLast);
772 
773   // Reshape the operator
774   status = at::native::xnnp_utils::xnnp_reshape_convolution2d_nhwc(
775       xnnp_convolution_op.get(),
776       N,
777       H,
778       W,
779       caffe2::pthreadpool_(),
780       per_channel(),
781       transpose(),
782       output_padding()[0],
783       output_padding()[1]);
784 
785   TORCH_CHECK(
786       status == xnn_status_success,
787       func_name,
788       ": xnn reshape operator failed(",
789       status,
790       ")");
791 
792   // Setup the operator
793   status = at::native::xnnp_utils::xnnp_setup_convolution2d_nhwc(
794       xnnp_convolution_op.get(),
795       reinterpret_cast<const underlying_t*>(act_nhwc.template data_ptr<scalar_t>()),
796       reinterpret_cast<underlying_t*>(output.template data_ptr<scalar_t>()),
797       per_channel(),
798       transpose());
799 
800   TORCH_CHECK(
801       status == xnn_status_success,
802       func_name,
803       ": xnn setup operator failed(",
804       status,
805       ")");
806 
807   // Run the operator
808   status = xnn_run_operator(
809       xnnp_convolution_op.get(), /* xnn_operator_t op */
810       caffe2::pthreadpool_()); /* pthreadpool_t threadpool */
811 
812   TORCH_CHECK(
813       status == xnn_status_success,
814       func_name,
815       ": xnn run operator failed(",
816       status,
817       ")");
818 
819   return output;
820 }
821 
822 #endif // USE_XNNPACK
823 
824 template <int kSpatialDim>
825 template <bool kReluFused>
826 at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_impl(
827     const at::Tensor& act,
828     double output_scale,
829     int64_t output_zero_point) {
830   // QNNPack is not thread safe
831   std::lock_guard<std::mutex> lock(qnnp_mutex_);
832   const std::string func_name = transpose() ? "quantized::conv_transpose"
833                                             : "quantized::conv";
834   TORCH_CHECK(!(kReluFused && transpose()),
836               func_name, kSpatialDim,
837               "d (qnnpack): ConvTranspose cannot be fused with ReLU.");
838   TORCH_CHECK(act.scalar_type() == c10::kQUInt8,
839               func_name,
840               "(qnnpack): Expected activation data type ",
841               toString(c10::kQUInt8),
842               " but got ",
843               toString(act.scalar_type()));
844   ConvDimChecks<kSpatialDim>(
845       act.ndimension(), stride().size(), padding().size(),
846       output_padding().size(), dilation().size(), func_name, transpose());
847 
848   auto* pack_w = w.get();
849 
850   // TODO Can be replaced with packB->getOutputChannels() when we update pre-pack
851   // to actually do the packing.
852   const int out_ch_idx = transpose() ? 1 : 0;
853   const auto out_ch = bias.size(0);
854   // inputs are in semantic NCHW format
855   const int N = act.size(0);
856   const int C = act.size(1);
857   const int D = kSpatialDim == 3 ? act.size(2) : 1;
858   const int H = act.size(kSpatialDim);
859   const int W = act.size(kSpatialDim + 1);
860   const int M = out_ch; // output channels
861 
862   const auto channels_last = kSpatialDim == 2
863       ? c10::MemoryFormat::ChannelsLast
864       : c10::MemoryFormat::ChannelsLast3d;
865   const at::Tensor act_ndhwc = act.contiguous(channels_last);
866 
867   auto output_min = kReluFused
868       // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
869       ? activationLimits<uint8_t>(output_scale, output_zero_point, Activation::RELU)
870             .first
871       : std::numeric_limits<uint8_t>::min();
872   auto output_max = kReluFused
873       // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
874       ? activationLimits<uint8_t>(output_scale, output_zero_point, Activation::RELU)
875             .second
876       : std::numeric_limits<uint8_t>::max();
877 
878   double act_input_scale = act_ndhwc.q_scale();
879 
880   // Re-quantizing the bias based on input scale and weight scale.
881   if (!input_scale.has_value() || input_scale.value() != act_input_scale) {
882     TORCH_CHECK(M == (transpose() ? groups() : 1) * orig_weight.size(out_ch_idx),
883         "Output channel size of weight and bias must match.");
884     TORCH_CHECK(C == (transpose() ? 1 : groups()) * orig_weight.size(1 - out_ch_idx),
885         "Input channel size of weight and bias must match.");
886 
887     // Get the original weight and adjust it to uint8 from int8
888     auto weight_contig = orig_weight.contiguous(channels_last);
889     auto bias_fp32 = bias;
890     int8_t* w_data =
891         reinterpret_cast<int8_t*>(weight_contig.template data_ptr<c10::qint8>());
892 
893     float* weight_scales_data = w_scales.data_ptr<float>();
894     // We calculate requant scale here as the vector holding the requant scale
895     // is owned by this module. The pointer is then passed to qnnpack backend.
896     generate_requantization_scales(
897         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
898         w_scales, act_input_scale, output_scale, requantization_scales);
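    // i.e. requantization_scales[i] is roughly act_input_scale * w_scales[i] /
    // output_scale, the factor that maps the int32 accumulator into the
    // output's quantized domain.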
899 
900     // TODO Kimish, we are allocating affine_quantized regardless of whether quantization is per channel or not.
901     // This allocation is actually used only for packing the weight and thus will be freed.
902     // Still we should be consistent. Fix this.
903     at::Tensor qnnp_weight = at::_empty_affine_quantized(
904         weight_contig.sizes(),
905         at::device(c10::kCPU).dtype(c10::kQUInt8).memory_format(channels_last),
906         weight_scales_data[0],
907         w_zero_points[0],
908         std::nullopt);
909     auto* qnnp_w_data = qnnp_weight.template data_ptr<c10::quint8>();
910     auto wt_numel = weight_contig.numel();
911     for (const auto i : c10::irange(wt_numel)) {
912       qnnp_w_data[i] = static_cast<c10::quint8>(w_data[i] + 128);
913     }
914     // Original bias was float, so we requantize it here.
915     at::Tensor qbias = quant_utils::QuantizeBias(convolution_op->per_channel, bias_fp32, weight_contig, act_input_scale);
916 
917     // Update the input scale to not pack again.
918     input_scale = act_input_scale;
919     w.reset();
920     w = std::make_unique<qnnpack::PrePackConvWeights>(
921         convolution_op.get(),
922         w_zero_points.data(),
923         reinterpret_cast<uint8_t*>(qnnp_w_data),
924         reinterpret_cast<int32_t*>(qbias.template data_ptr<c10::qint32>()));
925     pack_w = w.get();
926     if (at::globalContext().releaseWeightsWhenPrepacking()) {
927         // On mobile, we release the original weight by resetting the intrusive_ptr.
928         // Calling unpack after this will throw an assertion.
929         orig_weight.reset();
930     }
931 
932     // Fill the padding buffer with the activation zero point. This can only be
933     // done here if we want to do it just once.
934     if (zero_buffer_size) {
935       memset(
936           convolution_op->zero_buffer,
937           act_ndhwc.q_zero_point(),
938           zero_buffer_size);
939     }
940   }
941 
942   TORCH_INTERNAL_ASSERT(pack_w != nullptr, "Packed Weights are NULL");
943   at::SmallVector<int64_t, kSpatialDim + 2> output_shape;
944   const auto input_shape = MakeInputShape<kSpatialDim>(D, H, W);
945   if (transpose()) {
946     output_shape = MakeDeConvOutputShape<kSpatialDim>(
947         N,
948         M,
949         kSpatialDim == 2 ? std::vector<int64_t>{H, W} : std::vector<int64_t>{D, H, W},
950         kernel_,
951         stride(),
952         padding(),
953         output_padding(),
954         dilation());
955   } else {
956     output_shape = at::native::quantized::MakeConvOutputShape<kSpatialDim>(
957         N, M, input_shape, kernel_, stride(), padding(), dilation());
958   }
959 
960   if (act_ndhwc.numel() > 0) {
961     TORCH_CHECK(
962         std::all_of(
963             output_shape.begin(),
964             output_shape.end(),
965             [](int64_t i) { return i > 0; }),
966         func_name,
967         kSpatialDim,
968         "d (qnnpack): each dimension of output tensor should "
969         "be greater than 0.")
970   }
971 
972   // Allocate output Tensor and a buffer for QNNPACK to use
973   at::Tensor output = at::native::empty_affine_quantized(
974       output_shape,
975       c10::kQUInt8,
976       std::nullopt /* layout */,
977       c10::kCPU,
978       std::nullopt /* pin_memory */,
979       output_scale,
980       output_zero_point,
981       channels_last);
982 
983   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
984   pytorch_qnnp_status run_status;
985   if (transpose()) {
986     run_status = qnnpack::qnnpackDeConv(
987         convolution_op.get(),
988         pack_w->getPackedWeights(),
989         N,
990         H,
991         W,
992         act_ndhwc.q_zero_point(),
993         reinterpret_cast<uint8_t*>(act_ndhwc.template data_ptr<c10::quint8>()),
994         w_zero_points.data(),
995         requantization_scales.data(),
996         output.q_zero_point(),
997         output_min,
998         output_max,
999         reinterpret_cast<uint8_t*>(output.template data_ptr<c10::quint8>()),
1000         caffe2::pthreadpool_());
1001   } else {
1002     run_status = qnnpack::qnnpackConv(
1003         convolution_op.get(),
1004         pack_w->getPackedWeights(),
1005         N,
1006         D,
1007         H,
1008         W,
1009         act_ndhwc.q_zero_point(),
1010         reinterpret_cast<uint8_t*>(act_ndhwc.template data_ptr<c10::quint8>()),
1011         w_zero_points.data(),
1012         requantization_scales.data(),
1013         output.q_zero_point(),
1014         output_min,
1015         output_max,
1016         reinterpret_cast<uint8_t*>(output.template data_ptr<c10::quint8>()),
1017         caffe2::pthreadpool_());
1018   }
1019 
1020   TORCH_INTERNAL_ASSERT(
1021       run_status == pytorch_qnnp_status_success,
1022       "failed to run quantized::conv2d (qnnpack) operator");
1023 
1024   return output;
1025 }
1026 
1027 #ifdef USE_XNNPACK
1028 static bool can_use_xnnp(
1029     c10::ScalarType dtype,
1030     int kSpatialDim,
1031     bool per_channel,
1032     bool transpose) {
1033   if (!at::native::xnnpack::available()) {
1034     return false;
1035   }
1036   bool supported_dtypes = dtype == c10::kQInt8;
1037   bool invalid_config =
1038       (kSpatialDim != 2 /* No support for 3d convolution */
1039         || (dtype == c10::kQInt8 && transpose &&
1040             per_channel)); /* int8_t deconv does not support per-channel */
1041   if (supported_dtypes && invalid_config) {
1042     /* don't want this to fall through to QNNPACK */
1043     const std::string func_name =
1044         transpose ? "quantized::conv_transpose" : "quantized::conv";
1045     TORCH_CHECK(
1046         false,
1047         func_name,
1048         " (xnnpack): Unsupported conv config for dtype kQInt8");
1049   }
1050   return supported_dtypes && !invalid_config;
1051 }
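// E.g. a per-tensor qint8 2d conv is routed to XNNPACK; a qint8 3d conv or a
// per-channel qint8 deconv is rejected with an error here rather than silently
// falling back to QNNPACK, while a quint8 activation simply returns false and
// falls back.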
1052 #endif  // USE_XNNPACK
1053 
1054 template <int kSpatialDim>
1055 at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply(
1056     const at::Tensor& input,
1057     double output_scale,
1058     int64_t output_zero_point) {
1059 #ifdef USE_XNNPACK
1060   if (can_use_xnnp(input.scalar_type(), kSpatialDim, per_channel(), transpose())) {
1061     return apply_impl_xnnp<c10::qint8, false>(
1062         input, output_scale, output_zero_point);
1063   } /* fall through for unsupported types, configs, or shapes */
1064 #endif // USE_XNNPACK
1065   return apply_impl<false>(input, output_scale, output_zero_point);
1066 }
1067 
1068 template <int kSpatialDim>
1069 at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_relu(
1070     const at::Tensor& input,
1071     double output_scale,
1072     int64_t output_zero_point) {
1073 #ifdef USE_XNNPACK
1074   if (can_use_xnnp(input.scalar_type(), kSpatialDim, per_channel(), transpose())) {
1075     return apply_impl_xnnp<c10::qint8, true>(
1076         input, output_scale, output_zero_point);
1077   } /* fall through for unsupported types, configs, or shapes */
1078 #endif // USE_XNNPACK
1079   return apply_impl<true>(input, output_scale, output_zero_point);
1080 }
1081 
1082 template at::Tensor PackedConvWeightsQnnp<2>::apply(
1083     const at::Tensor& act,
1084     double output_scale,
1085     int64_t output_zero_point);
1086 
1087 template at::Tensor PackedConvWeightsQnnp<2>::apply_relu(
1088     const at::Tensor& act,
1089     double output_scale,
1090     int64_t output_zero_point);
1091 
1092 template at::Tensor PackedConvWeightsQnnp<3>::apply(
1093     const at::Tensor& act,
1094     double output_scale,
1095     int64_t output_zero_point);
1096 
1097 template at::Tensor PackedConvWeightsQnnp<3>::apply_relu(
1098     const at::Tensor& act,
1099     double output_scale,
1100     int64_t output_zero_point);
1101 
1102 template at::Tensor PackedConvWeightsQnnp<2>::apply_impl<false>(
1103     const at::Tensor& act,
1104     double output_scale,
1105     int64_t output_zero_point);
1106 
1107 template at::Tensor PackedConvWeightsQnnp<3>::apply_impl<false>(
1108   const at::Tensor& act,
1109   double output_scale,
1110   int64_t output_zero_point);
1111 
1112 #endif // USE_PYTORCH_QNNPACK
1113 
1114 #if AT_MKLDNN_ENABLED()
1115 template <int kSpatialDim>
1116 at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply(
1117     const at::Tensor& input,
1118     double output_scale,
1119     int64_t output_zero_point) {
1120   return apply_impl<false>(input, std::nullopt, output_scale, output_zero_point);
1121 }
1122 
1123 template <int kSpatialDim>
1124 at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_relu(
1125     const at::Tensor& input,
1126     double output_scale,
1127     int64_t output_zero_point) {
1128   return apply_impl<true>(input, std::nullopt, output_scale, output_zero_point);
1129 }
1130 
1131 template <int kSpatialDim>
1132 at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_add(
1133     const at::Tensor& input,
1134     const at::Tensor& accum,
1135     double output_scale,
1136     int64_t output_zero_point) {
1137   TORCH_CHECK(kSpatialDim == 2, " Currently, only conv2d with add is supported.");
1138   return apply_impl<false>(input, accum, output_scale, output_zero_point);
1139 }
1140 
1141 template <int kSpatialDim>
1142 at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_add_relu(
1143     const at::Tensor& input,
1144     const at::Tensor& accum,
1145     double output_scale,
1146     int64_t output_zero_point) {
1147   TORCH_CHECK(kSpatialDim == 2, " Currently, only conv2d add relu is supported.");
1148   return apply_impl<true>(input, accum, output_scale, output_zero_point);
1149 }
1150 
1151 template <int kSpatialDim>
1152 template <bool kReluFused>
1153 at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_impl(
1154     const at::Tensor& act,
1155     const std::optional<at::Tensor>& accum,
1156     double output_scale,
1157     int64_t output_zero_point) {
1158   std::string func_name = "quantized::conv";
1159   if (transpose()) {
1160     func_name += "_transpose";
1161   }
1162   func_name += std::to_string(kSpatialDim) + "d";
1163 
1164   // has_accum: an extra input besides the conv input, used for conv-add fusion.
1165   bool has_accum = accum.has_value();
1166   if (has_accum) {
1167     auto& ctx = at::globalContext();
1168     func_name += "_add";
1169     TORCH_CHECK(
1170       !transpose(),
1171       "Transposed conv is not supported for conv with add ",
1172       c10::toString(ctx.qEngine()));
1173   }
1174 
1175   if (kReluFused) {
1176     func_name += "_relu";
1177   }
1178   ConvDimChecks<kSpatialDim>(
1179       act.ndimension(), stride().size(), padding().size(),
1180       output_padding().size(), dilation().size(), func_name, transpose());
1181   TORCH_CHECK(act.scalar_type() == c10::ScalarType::QUInt8,
1182       func_name, " (ONEDNN): data type of input should be QUint8.");
1183 
1184   // src
1185   auto act_contig = act.contiguous(kSpatialDim == 2 ? c10::MemoryFormat::ChannelsLast : c10::MemoryFormat::ChannelsLast3d);
1186   auto src_dims = act_contig.sizes().vec();
1187   auto src_data_type = dnnl::memory::data_type::u8;
1188   auto src_desc = ideep::tensor::desc(src_dims, src_data_type,
1189       kSpatialDim == 2 ? ideep::format_tag::nhwc : ideep::format_tag::ndhwc);
1190   ideep::tensor src(src_desc, act_contig.data_ptr());
1191   // weights & bias
1192   ideep::tensor& weights = *(weight_.get());
1193   bool with_bias = bias_.has_value();
1194   const auto& kernel_size = weights.get_dims();
1195   // dst
1196   const std::vector<int64_t>& input_size = src.get_dims();
1197   std::vector<int64_t> output_sizes;
1198   if (transpose()) {
1199     // Prepacked weight format: [o, i, ...]
1200     const int N = act.size(0); // batch size
1201     const int C = act.size(1); // input channels
1202     const int M = weights.get_dim(0); // output channels
1203     const int D = kSpatialDim == 2 ? 1 : act.size(2); // input depth
1204     const int H = act.size(kSpatialDim); // input height
1205     const int W = act.size(kSpatialDim + 1); // input width
1206     const int KH = weights.get_dim(kSpatialDim); // kernel height
1207     const int KW = weights.get_dim(kSpatialDim + 1); // kernel width
1208     const int KD = kSpatialDim == 2 ? 1 : weights.get_dim(2); // kernel depth
1209     TORCH_CHECK(C == groups() * weights.get_dim(1), // weight: [o, i, ...]
1210                 func_name, " (ONEDNN): input channel number should be ",
1211                 groups() * weights.get_dim(1), ", but got ", C);
1212     auto output_shape = MakeDeConvOutputShape<kSpatialDim>(
1213         N,
1214         M,
1215         kSpatialDim == 2 ? std::vector<int64_t>{H, W} : std::vector<int64_t>{D, H, W},
1216         kSpatialDim == 2 ? std::vector<int64_t>{KH, KW} : std::vector<int64_t>{KD, KH, KW},
1217         stride(),
1218         padding(),
1219         output_padding(),
1220         dilation());
1221     output_sizes = c10::IntArrayRef(output_shape).vec();
1222   } else {
1223     output_sizes = at::native::conv_output_size(input_size, kernel_size, padding().vec(), stride().vec(), dilation().vec());
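    // conv_output_size applies the standard formula
    // out = (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1
    // per spatial dim; e.g. an 8x8 input with kernel 3, stride 2, padding 1
    // and dilation 1 gives a 4x4 output.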
1224   }
1225   ideep::dims dst_dims = ideep::dims({output_sizes.cbegin(), output_sizes.cend()});
1226   at::Tensor output = at::_empty_affine_quantized(
1227       dst_dims,
1228       device(c10::kCPU)
1229           .dtype(c10::kQUInt8)
1230           .memory_format(kSpatialDim == 2 ?
1231               c10::MemoryFormat::ChannelsLast :
1232               c10::MemoryFormat::ChannelsLast3d),
1233       output_scale,
1234       output_zero_point,
1235       std::nullopt);
1236   if (output.numel() == 0) {
1237     return output;
1238   }
1239   ideep::tensor dst;
1240   at::Tensor accum_contig;
1241   if (has_accum) {
1242     auto dst_desc = ideep::tensor::desc(dst_dims, src_data_type,
1243         kSpatialDim == 2 ? ideep::format_tag::nhwc : ideep::format_tag::ndhwc);
1244     accum_contig = accum.value().contiguous(kSpatialDim == 2 ? c10::MemoryFormat::ChannelsLast : c10::MemoryFormat::ChannelsLast3d);
1245     TORCH_CHECK(accum_contig.dtype() == output.dtype(), "The output tensor should have the same dtype as the accum tensor.");
1246     // When fused with sum, the dst tensor will share the data ptr with the accum tensor.
1247     dst.init(dst_desc, accum_contig.data_ptr());
1248   } else {
1249     dst = ideep::tensor({dst_dims, ideep::tensor::data_type::u8, {output.strides().cbegin(), output.strides().cend()}},
1250                       output.data_ptr());
1251   }
1252 
1253   // Parameters
1254   const ideep::dims& strides = stride().vec();
1255   const ideep::dims& dilates = dilation().vec();
1256   const ideep::dims& padding_l = padding().vec();
1257   const ideep::dims& padding_r = padding().vec();
1258   double input_scale = act.q_scale();
1259   int64_t input_zp = act.q_zero_point();
1260   // The scales used by ONEDNN and PyTorch are reciprocals of each other
1261   const ideep::scale_t& src_scales = ideep::scale_t(1, 1.0/input_scale);
1262   const ideep::scale_t& weights_scales = weights.get_scale();
1263   double inv_output_scale = 1.0/output_scale;
1264   const ideep::zero_point_t src_zero_points = ideep::zero_point_t(1, input_zp);
1265   const ideep::zero_point_t dst_zero_points = ideep::zero_point_t(1, output_zero_point);
1266 
1267   ideep::attr_t op_attr;
1268   float sum_scale = has_accum ? accum.value().q_scale() : 1.0;
1269   int32_t sum_zero_point = has_accum ? accum.value().q_zero_point() : 0;
1270   if (has_accum) {
1271     // Just indicates that we have these post ops; the actual values such as scale and zero point will be set later.
1272     op_attr = kReluFused ? ideep::attr_t::residual_with_sum_zero_point() : ideep::attr_t::fuse_sum();
1273     const ideep::scale_t accum_scale = ideep::scale_t(1, 1.0/sum_scale);
1274     const ideep::zero_point_t accum_zero_points = ideep::zero_point_t(1, sum_zero_point);
1275     // Set the dst scale and zero point to the values of accum.
1276     // The true scale and zero point are stored in ideep::scale_t(scale_size, inv_output_scale) and dst_zero_points.
1277     dst.set_scale(accum_scale);
1278     dst.set_zero_point(accum_zero_points);
1279   } else if (kReluFused) {
1280     op_attr = ideep::attr_t::fuse_relu();
1281   }
1282 
1283   // Bias might be modified outside (e.g. by quantization bias correction).
1284   // If so, update the prepacked bias as well.
1285   if (with_bias && bias_.value().get_data_handle() != orig_bias_.value().data_ptr()) {
1286     bias_.value().init(bias_.value().get_desc(), orig_bias_.value().data_ptr());
1287   }
1288   const auto& b = with_bias ? bias_.value() : ideep::tensor();
1289   int num_threads = at::get_num_threads();
1290   if (transpose()) {
1291     // Primitive cache is initialized when called for the first time
1292     // and won't be updated afterwards.
1293     PrimitiveCacheKey cache_key = std::make_tuple(
1294         input_scale, input_zp, src_dims, output_scale, output_zero_point, num_threads, sum_scale, sum_zero_point);
1295     c10::call_once(*cache_initialized_flag, [&](){
1296         DeconvParams params;
1297         ideep::convolution_transpose_forward::prepare(
1298             params, src, weights, b, dst_dims, dst,
1299             strides, padding_l, padding_r, dilates, groups(),
1300             src_scales, weights_scales, ideep::scale_t(1, inv_output_scale),
1301             src_zero_points, dst_zero_points, op_attr,
1302             dnnl::algorithm::deconvolution_direct,
1303             dnnl::prop_kind::forward_inference,
1304             ideep::u8s8, ideep::engine::cpu_engine());
1305         get_deconv_cache() = DeconvPrimitiveCache(cache_key, params);
1306         auto expected_weight_desc = ideep::tensor::desc(params.pd.weights_desc(), groups());
1307         weights = weights.reorder_if_differ_in(expected_weight_desc);
1308     });
1309     if (get_deconv_cache().hit(cache_key)) {
1310       DeconvParams& params = get_deconv_cache().get_params();
1311       ideep::convolution_transpose_forward::compute<false, false>(
1312           params, src, weights, b, dst);
1313     } else {
1314       ideep::convolution_transpose_forward::compute(
1315           src, weights, b, dst_dims, dst,
1316           strides, padding_l, padding_r, dilates,
1317           groups(), src_scales, weights_scales,
1318           ideep::scale_t(1, inv_output_scale),
1319           src_zero_points, dst_zero_points, op_attr,
1320           dnnl::algorithm::deconvolution_direct,
1321           dnnl::prop_kind::forward_inference,
1322           ideep::u8s8, ideep::engine::cpu_engine());
1323     }
1324   } else {  // not transposed
1325     PrimitiveCacheKey cache_key = std::make_tuple(
1326         input_scale, input_zp, src_dims, output_scale, output_zero_point, num_threads, sum_scale, sum_zero_point);
1327     c10::call_once(*cache_initialized_flag, [&](){
1328         ConvParams params;
1329         ideep::convolution_forward::prepare(
1330             params, src, weights, b, dst_dims, dst,
1331             strides, dilates, padding_l, padding_r, groups(),
1332             src_scales, weights_scales, ideep::scale_t(1, inv_output_scale),
1333             src_zero_points, dst_zero_points,
1334             op_attr, dnnl::algorithm::convolution_direct,
1335             dnnl::prop_kind::forward_inference,
1336             ideep::u8s8, ideep::engine::cpu_engine());
1337         get_conv_cache() = ConvPrimitiveCache(cache_key, params);
1338         auto expected_weight_desc = ideep::tensor::desc(params.pd.weights_desc(), groups());
1339         weights = weights.reorder_if_differ_in(expected_weight_desc);
1340     });
1341     // If hit, use cached data. If miss, fall back to normal path.
1342     if (get_conv_cache().hit(cache_key)) {
1343       auto& params = get_conv_cache().get_params();
1344       ideep::convolution_forward::compute<false, false>(params, src, weights, b, dst);
1345     } else {
1346       ideep::convolution_forward::compute(
1347           src, weights, b, dst_dims, dst,
1348           strides, dilates, padding_l, padding_r, groups(),
1349           src_scales, weights_scales, ideep::scale_t(1, inv_output_scale),
1350           src_zero_points, dst_zero_points, op_attr,
1351           dnnl::algorithm::convolution_direct,
1352           dnnl::prop_kind::forward_inference,
1353           ideep::u8s8, ideep::engine::cpu_engine());
1354     }
1355   }
1356   if (has_accum) {
1357     // When fused with sum, the accum tensor shares its data ptr with the dst tensor and serves as the output.
1358     // Reset output's scale and zero point into accum_contig.
1359     set_quantizer_(accum_contig, at::make_per_tensor_affine_quantizer(
1360         output_scale, output_zero_point, accum_contig.scalar_type()));
1361     return accum_contig;
1362   } else {
1363     return output;
1364   }
1365 }
1366 
1367 template at::Tensor PackedConvWeightsOnednn<2>::apply(
1368     const at::Tensor& act,
1369     double output_scale,
1370     int64_t output_zero_point);
1371 
1372 template at::Tensor PackedConvWeightsOnednn<2>::apply_relu(
1373     const at::Tensor& act,
1374     double output_scale,
1375     int64_t output_zero_point);
1376 
1377 template at::Tensor PackedConvWeightsOnednn<3>::apply(
1378     const at::Tensor& act,
1379     double output_scale,
1380     int64_t output_zero_point);
1381 
1382 template at::Tensor PackedConvWeightsOnednn<3>::apply_relu(
1383     const at::Tensor& act,
1384     double output_scale,
1385     int64_t output_zero_point);
1386 
1387 static at::Tensor _quantized_convolution_onednn(
1388     at::Tensor act, // contains quantized values but not QTensor
1389     double act_scale,
1390     int64_t act_zero_point,
1391     at::Tensor weight, // MKLDNN tensor with quantized values
1392     at::Tensor weight_scales,
1393     at::Tensor weight_zero_points,
1394     std::optional<at::Tensor> bias, // Bias is not packed into MKLDNN tensor
1395     torch::List<int64_t> stride,
1396     torch::List<int64_t> padding,
1397     torch::List<int64_t> dilation,
1398     bool transposed,
1399     int64_t groups,
1400     double output_scale,
1401     int64_t output_zero_point,
1402     std::optional<at::Tensor> accum, // accum to be fused with conv add
1403     double accum_scale,
1404     int64_t accum_zero_point,
1405     std::optional<c10::ScalarType> output_dtype,
1406     std::optional<c10::string_view> binary_attr,
1407     std::optional<at::Scalar> binary_alpha,
1408     std::optional<c10::string_view> unary_attr,
1409     torch::List<std::optional<at::Scalar>> unary_scalars,
1410     std::optional<c10::string_view> unary_algorithm) {
1411   /*********************************/
1412   /*          Checks               */
1413   /*********************************/
1414   // Due to the constant folding inside Inductor freeze,
1415   // https://github.com/pytorch/pytorch/blob/b99d605a3070de35677cc43f0196c2f2e807b822/torch/ao/quantization/fx/_decomposed.py#L62-L63
1416   // inv_scale = 1.0 / scale will be folded.
1417   // So we can only get inv_scale from the quant node, which is used as
1418   // the output_scale of this op.
1419   bool fp32_output = output_dtype.has_value() && (output_dtype.value() == c10::kFloat);
1420   bool bfloat16_output = output_dtype.has_value() && (output_dtype.value() == c10::kBFloat16);
1421   if (fp32_output || bfloat16_output) {
1422     // For fp32 or bf16 output, oneDNN expects op_attr without set_scales and set_zero_points.
1423     // So we use the default output_scale of 1.0 and output_zero_point of 0:
1424     // when output_scale is 1.0, we skip invoking op_attr.set_scales in ideep;
1425     // when output_zero_point is 0, we skip invoking op_attr.set_zero_points in ideep.
1426     TORCH_CHECK(output_scale == 1.0,  " (ONEDNN): for fp32 or bf16 output, output_scale must be 1.0.");
1427     TORCH_CHECK(output_zero_point == 0,  " (ONEDNN): for fp32 or bf16 output, output_zero_point must be 0.");
1428   }
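  // E.g. for the int8-mixed-bf16 path (output_dtype == c10::kBFloat16), callers are
  // expected to pass output_scale = 1.0 and output_zero_point = 0, so no output
  // requantization is attached to the primitive and the result stays in bf16.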
1429 
1430   int kSpatialDim = act.dim() - 2;
1431   bool is_1d = (1 == kSpatialDim);
1432 
1433   bool has_binary_post_op = binary_attr.has_value() && binary_attr.value() != "none";
1434   bool has_unary_post_op = unary_attr.has_value() && unary_attr.value() != "none";
1435   // has_accum_postop_sum: an extra input, besides the conv inputs, is needed for conv post-op sum fusion.
1436   bool has_accum_postop_sum = has_binary_post_op && binary_attr.value() == "sum";
1437 
1438   if (has_accum_postop_sum) {
1439     TORCH_CHECK(accum.has_value(), "For post op sum, accum tensor should not be empty.");
1440     TORCH_CHECK(
1441       accum.value().is_contiguous(
1442         kSpatialDim == 2
1443         ? c10::MemoryFormat::ChannelsLast
1444         : c10::MemoryFormat::ChannelsLast3d
1445       ),
1446       "For post op sum, accum tensor must be contiguous."
1447     );
1448     if (fp32_output || bfloat16_output) {
1449       TORCH_CHECK(accum_scale == 1.0,  " (ONEDNN): for fp32 or bf16 output, accum_scale must be 1.0.");
1450       TORCH_CHECK(accum_zero_point == 0,  " (ONEDNN): for fp32 or bf16 output, accum_zero_point must be 0.");
1451       TORCH_CHECK((accum.value().scalar_type() == c10::kFloat) || (accum.value().scalar_type() == c10::kBFloat16), "The accum tensor should be kFloat or kBFloat16.");
1452     }
1453   }
1454 
1455   std::string func_name = "quantized::packed_weights_conv";
1456   func_name += std::to_string(kSpatialDim) + "d";
1457   if (has_binary_post_op) {
1458     func_name += binary_attr.value().data();
1459   }
1460   if (has_unary_post_op) {
1461     func_name += unary_attr.value().data();
1462   }
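  // For example, a 2D conv fused with binary post-op "sum" and unary post-op "relu"
  // yields func_name == "quantized::packed_weights_conv2dsumrelu"; the name is only
  // used to make the error messages below more readable.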
1463 
1464   if (kSpatialDim == 1) {
1465     kSpatialDim += 1;
1466   }
1467   TORCH_CHECK(
1468     weight.is_mkldnn(),
1469     func_name, ": Weight should be prepacked as an MKLDNN tensor"
1470   );
1471   if (transposed) {
1472     TORCH_CHECK(
1473       false,
1474       func_name, ": transposed convolution is not supported."
1475     );
1476   }
1477   if (is_1d) {
1478     // N, C, L -> N, C, 1, L
1479     act = act.unsqueeze(quant_utils::kConv1dSqueezeDim + 2);
1480     stride = quant_utils::MakeArgForConv1d(stride, 1);
1481     padding = quant_utils::MakeArgForConv1d(padding, 0);
1482     dilation = quant_utils::MakeArgForConv1d(dilation, 1);
1483   }
1484   TORCH_CHECK(
1485     act.scalar_type() == c10::ScalarType::Byte,
1486     func_name, ": Input tensor should have uint8 (unsigned char) data type");
1487   TORCH_CHECK(
1488     weight.scalar_type() == c10::ScalarType::Char,
1489     func_name, ": Weight tensor should have int8 (char) data type");
1490   TORCH_CHECK(
1491     weight.ndimension() == kSpatialDim + 2,
1492     func_name, ": Weights are expected to have ", kSpatialDim + 2, " dimensions");
1493   TORCH_CHECK(
1494     stride.size() == (decltype(stride.size()))kSpatialDim,
1495     func_name, ": stride should contain ", kSpatialDim, " elements for ",
1496     kSpatialDim, "D convolution.");
1497   TORCH_CHECK(
1498     padding.size() == (decltype(padding.size()))kSpatialDim,
1499     func_name, ": Specify front/top/left padding only. "
1500     "end/bottom/right padding assumed to be equal to front/top/left");
1501   TORCH_CHECK(
1502     dilation.size() == (decltype(dilation.size()))kSpatialDim,
1503     func_name, ": dilation should contain ", kSpatialDim, " elements for ",
1504     kSpatialDim, "D convolution.");
1505 
1506   // Parameters
1507 #if IDEEP_PREREQ(3, 1, 0, 1)
1508   // 1. If the weight scale is generated by an observer, it has dtype float32:
1509   // https://github.com/pytorch/pytorch/blob/d2c24eca8a60c56b31ca967a44d5cc4522802aa6/torch/ao/quantization/observer.py#L323
1510   // 2. If the weight scale is taken from a quantized tensor, as done in the UT, it has dtype double:
1511   // https://github.com/pytorch/pytorch/blob/d2fa3f608b5e4f582a8aaf752f10efe4ca72a7d0/aten/src/ATen/quantized/Quantizer.cpp#L69
1512   TORCH_CHECK(
1513     weight_scales.scalar_type() == c10::ScalarType::Double || weight_scales.scalar_type() == c10::ScalarType::Float,
1514     "weight_scales should have data type Double or Float");
1515   if (weight_scales.scalar_type() == c10::ScalarType::Double) {
1516     // For case 2, we convert it from double to float, since ideep::scale_t is an alias of std::vector<float>.
1517     weight_scales = weight_scales.to(c10::ScalarType::Float);
1518   }
1519   TORCH_CHECK(
1520     weight_scales.ndimension() == 0 ||
1521     (weight_scales.strides().size() == 1 || weight_scales.stride(0) == 1),
1522     "weight_scales should be a scalar tensor or a contiguous 1D tensor.");
1523   ideep::scale_t weights_scales(weight_scales.data_ptr<float>(), weight_scales.data_ptr<float>()+weight_scales.numel());
1524 #elif IDEEP_PREREQ(3, 1, 0, 0)
1525   // TODO (leslie): optimize the performance here:
1526   // 1. Remove the reciprocal of the weight scale here, since the reciprocal is taken again inside ideep:
1527   // https://github.com/intel/ideep/blob/3c90e365526e19c110371d23831678a7e9d4353d/include/ideep/operators/conv.hpp#L163-L168
1528   // 2. Remove 2 memory copies of weight_scales:
1529   //   2.1 The input weight_scales is a PyTorch dense tensor; we convert it to the vector<float> weights_scales.
1530   //   2.2 The oneDNN stream submission converts weights_scales from a vector to an ideep::tensor:
1531   //   https://github.com/intel/ideep/blob/3c90e365526e19c110371d23831678a7e9d4353d/include/ideep/operators/conv.hpp#L1855-L1860
1532   // We should be able to convert weight_scales directly from a PyTorch dense tensor to an ideep tensor that shares the same data ptr.
1533   ideep::scale_t weights_scales(weight_scales.numel());
1534   if (weight_scales.ndimension() == 0) {
1535     // Weight is quantized per tensor, so weight_scales is a scalar Tensor
1536     weights_scales[0] = 1.0 / weight_scales.item().toDouble(); // Scales of ONEDNN and PyTorch are reciprocal
1537   } else {
1538     // Weight is quantized per channel
1539     for (int i = 0; i < weight_scales.numel(); ++i) {
1540       weights_scales[i] = 1.0 / weight_scales[i].item().toDouble();
1541     }
1542   }
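  // A quick worked example of the reciprocal convention above: a per-tensor PyTorch
  // weight scale of 0.02 becomes a oneDNN scale of 1.0 / 0.02 = 50.0 in weights_scales[0].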
1543 #else
1544   TORCH_CHECK(false, "Unexpected IDeep version to do qconv calculation.");
1545 #endif
1546 
1547   const ideep::zero_point_t src_zero_points = ideep::zero_point_t(1, act_zero_point);
1548   const ideep::zero_point_t dst_zero_points = ideep::zero_point_t(1, output_zero_point);
1549 
1550   // Weight
1551   auto packed_weight = at::native::itensor_from_mkldnn(weight);
1552 
1553   // Bias
1554   ideep::tensor onednn_bias;
1555   const int output_channels = weight.size(0);
1556   bool with_bias = bias.has_value();
1557 
1558   at::Tensor bias_val_float;
1559   if (with_bias) {
1560     // For int8-mixed-bf16, we will also use float32 bias
1561     bias_val_float = bias.value().to(at::kFloat);
1562     TORCH_CHECK(bias_val_float.dim() == 1, "bias should be a vector (1D Tensor)");
1563     TORCH_CHECK(
1564         bias_val_float.size(0) == output_channels,
1565         "bias should have K elements: " + std::to_string(output_channels));
1566     auto bias_desc = ideep::tensor::desc(bias_val_float.sizes().vec(), dnnl::memory::data_type::f32);
1567     onednn_bias.init(bias_desc, bias_val_float.data_ptr());
1568   }
1569 
1570   const auto& expected_bias = with_bias ? onednn_bias : ideep::tensor();
1571 
1572   /*********************************/
1573   /*        Computation            */
1574   /*********************************/
1575   // src
1576   auto act_contig = act.contiguous(kSpatialDim == 2 ?
1577                                    c10::MemoryFormat::ChannelsLast :
1578                                    c10::MemoryFormat::ChannelsLast3d);
1579   auto src_dims = act_contig.sizes().vec();
1580   auto src_data_type = dnnl::memory::data_type::u8;
1581   auto src_desc = ideep::tensor::desc(src_dims, src_data_type,
1582       kSpatialDim == 2 ? ideep::format_tag::nhwc : ideep::format_tag::ndhwc);
1583   ideep::tensor src;
1584   src.init(src_desc, act_contig.data_ptr());
1585   // dst
1586   const std::vector<int64_t>& input_size = src.get_dims();
1587   const auto& kernel_size = packed_weight.get_dims();
1588   std::vector<int64_t> output_sizes;
1589   output_sizes = at::native::conv_output_size(input_size, kernel_size, padding.vec(), stride.vec(), dilation.vec());
1590   ideep::dims dst_dims = ideep::dims({output_sizes.cbegin(), output_sizes.cend()});
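  // For illustration, with an NCHW activation of 1x3x224x224, a 64x3x7x7 kernel,
  // stride 2, padding 3 and dilation 1, conv_output_size computes
  // floor((224 + 2*3 - 1*(7 - 1) - 1) / 2) + 1 = 112 per spatial dim, so
  // dst_dims = {1, 64, 112, 112}.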
1591   // Output is not a quantized tensor; its data type is uint8 (or fp32/bf16 when requested via output_dtype)
1592   at::Tensor output = has_accum_postop_sum ?
1593     accum.value() :
1594     at::empty(
1595       dst_dims,
1596       device(c10::kCPU)
1597           .dtype(fp32_output ? c10::kFloat : (bfloat16_output ? c10::kBFloat16 : c10::kByte))
1598           .memory_format(kSpatialDim == 2 ?
1599               c10::MemoryFormat::ChannelsLast :
1600               c10::MemoryFormat::ChannelsLast3d)
1601     );
1602   if (output.numel() == 0) {
1603     return output;
1604   }
1605   ideep::tensor dst = at::native::itensor_view_from_dense(output);
1606   static ideep::tensor::desc dummy_accum_desc;
1607   ideep::attr_t op_attr = onednn_utils::create_attr_by_post_op(
1608     binary_attr.has_value() ? binary_attr.value() : "none",
1609     binary_alpha.has_value() ? binary_alpha.value().to<double>() : 1.0,
1610     accum_scale,
1611     accum_zero_point,
1612     dummy_accum_desc,
1613     unary_attr.has_value() ? unary_attr.value() : "none",
1614     unary_scalars,
1615     unary_algorithm.has_value() ? unary_algorithm.value() : ""
1616   );
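  // For instance, binary_attr == "sum" combined with unary_attr == "relu" is expected
  // to produce a post-op chain that accumulates into the existing dst values and then
  // applies an eltwise ReLU, while "none"/"none" leaves the attr without post-ops.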
1617 
1618 #if IDEEP_PREREQ(3, 1, 0, 0)
1619   // Use oneDNN's APIs instead of prepare/compute from ideep to reduce integration overhead.
1620   // The functions from ideep are heavy because they use complex data structures for a unified API.
1621   // oneDNN version >= 3.1.0 is required.
1622   using ideep::tensor;
1623   auto weight_grouped = packed_weight.make_grouped_weights(groups, /* is_deconv */false);
1624   auto weights_desc = tensor::desc(weight_grouped.get_dims(), ideep::data_type::s8, ideep::format_tag::any);
1625   if (groups > 1) {
1626     weights_desc = weights_desc.to_grouped(groups);
1627   }
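  // Sketch of the grouped view (assuming the usual oneDNN grouped-weight layout):
  // with groups == 2, an OIHW weight of shape (64, 16, 3, 3) is expected to be seen
  // as the 5-D grouped shape (2, 32, 16, 3, 3).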
1628   auto dst_desc = dst.get_desc();
1629   auto bias_desc = with_bias ?
1630       tensor::desc(expected_bias.get_dims(), ideep::data_type::f32, ideep::format_tag::any) :
1631       tensor::desc();
1632   if (act_scale != 1.0f) {
1633     op_attr.set_scales_mask(DNNL_ARG_SRC, 0);
1634   }
1635   if (act_zero_point != 0) {
1636     op_attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
1637   }
1638   int oc_per_group = weight_grouped.get_dim(0) / groups;
1639   int wei_scale_mask = ideep::utils::conv_weight_scale_mask(weight_scales.numel(), oc_per_group, groups, false);
1640   op_attr.set_scales_mask(DNNL_ARG_WEIGHTS, wei_scale_mask);
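  // A scales mask of 0 means a single common scale for the whole tensor; for
  // per-output-channel weight scales, conv_weight_scale_mask is expected to select
  // the output-channel dimension (plus the groups dimension when groups > 1).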
1641   if (output_scale != 1.0f) {
1642     op_attr.set_scales_mask(DNNL_ARG_DST, 0);
1643   }
1644   if (output_zero_point != 0) {
1645     op_attr.set_zero_points_mask(DNNL_ARG_DST, 0);
1646   }
1647   op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
1648   auto engine = ideep::engine::cpu_engine();
1649   auto dilates_dnnl = ideep::utils::get_compatible_dilates(dilation.vec());
1650   auto primitive_desc = with_bias ?
1651       dnnl::convolution_forward::primitive_desc(
1652         engine, dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct,
1653         src_desc, weights_desc, bias_desc, dst_desc,
1654         stride.vec(), dilates_dnnl, padding.vec(), padding.vec(), op_attr
1655       ) :
1656       dnnl::convolution_forward::primitive_desc(
1657         engine, dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct,
1658         src_desc, weights_desc, dst_desc,
1659         stride.vec(), dilates_dnnl, padding.vec(), padding.vec(), op_attr
1660       );
1661   auto primitive = dnnl::convolution_forward(primitive_desc);
1662 
1663   // Reorder weight if needed
1664   auto expected_weight = weight_grouped.reorder_if_differ_in(primitive_desc.weights_desc());
1665 
1666   // Prepare args and execute primitive
1667   tensor scratchpad(primitive_desc.scratchpad_desc());
1668   ideep::exec_args args;
1669   args.insert({DNNL_ARG_SRC, src});
1670   args.insert({DNNL_ARG_WEIGHTS, expected_weight});
1671   args.insert({DNNL_ARG_DST, dst});
1672   args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
1673   if (with_bias) {
1674     args.insert({DNNL_ARG_BIAS, expected_bias});
1675   }
1676   tensor src_scales_t = tensor(ideep::scale_t(1, act_scale));
1677   tensor wei_scales_t = tensor(weights_scales);
1678   tensor dst_scales_t = tensor(ideep::scale_t(1, output_scale));
1679   tensor src_zp_t = tensor(ideep::zero_point_t(1, act_zero_point));
1680   tensor dst_zp_t = tensor(ideep::zero_point_t(1, output_zero_point));
1681   if (act_scale != 1.0f) {
1682     args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_t});
1683   }
1684   if (output_scale != 1.0f) {
1685     args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_t});
1686   }
1687   args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_t});
1688   if (act_zero_point != 0) {
1689     args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_t});
1690   }
1691   if (output_zero_point != 0) {
1692     args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_t});
1693   }
1694   primitive.execute(ideep::stream::default_stream(), args);
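  // The DNNL_ARG_ATTR_SCALES / DNNL_ARG_ATTR_ZERO_POINTS entries above supply the
  // runtime values for the masks set on op_attr earlier: the src/dst entries are
  // omitted exactly when the corresponding mask was not set (scale == 1.0 or zero
  // point == 0), while the weight scales are always passed.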
1695 #else
1696   // Scales of ONEDNN and PyTorch are reciprocal
1697   const ideep::scale_t& src_scales = ideep::scale_t(1, 1.0 / act_scale);
1698 
1699   // set accum scale/zero point to dst
1700   if (has_accum_postop_sum) {
1701     const ideep::scale_t accum_ideep_scale = ideep::scale_t(1, 1.0/accum_scale);
1702     const ideep::zero_point_t accum_ideep_zero_points = ideep::zero_point_t(1, accum_zero_point);
1703     // Set the dst scale and zero point to the values of accum.
1704     // The true scale and zero point are stored in ideep::scale_t(scale_size, output_scale) and dst_zero_points.
1705     dst.set_scale(accum_ideep_scale);
1706     dst.set_zero_point(accum_ideep_zero_points);
1707   }
1708 
1709   // Weight Reorder
1710   ConvParams params;
1711   ideep::convolution_forward::prepare(
1712       params, src, packed_weight, expected_bias, dst_dims, dst,
1713       stride.vec(), dilation.vec(), padding.vec(), padding.vec(), groups,
1714       src_scales, weights_scales, ideep::scale_t(1, 1.0f / output_scale),
1715       src_zero_points, dst_zero_points,
1716       op_attr, dnnl::algorithm::convolution_direct,
1717       dnnl::prop_kind::forward_inference,
1718       ideep::u8s8, ideep::engine::cpu_engine());
1719   auto expected_weight_desc = ideep::tensor::desc(params.pd.weights_desc(), groups);
1720   ideep::tensor expected_weight = packed_weight.reorder_if_differ_in(expected_weight_desc);
1721 
1722   // Computation
1723   ideep::convolution_forward::compute<false, false>(params, src, expected_weight, expected_bias, dst);
1724 #endif
1725 
1726   if (is_1d) {
1727     output.squeeze_(quant_utils::kConv1dSqueezeDim + 2);
1728     return output;
1729   }
1730   if (has_accum_postop_sum) {
1731     return accum.value();
1732   } else {
1733     return output;
1734   }
1735 }
1736 
1737 #endif // #if AT_MKLDNN_ENABLED()
1738 
1739 namespace at::native {
1740 namespace {
1741 
1742 /*
1743  * FBGEMM uses vpmaddubsw instruction to multiply activations (uint8_t) and
1744  * weights (int8_t).
1745  *
1746  * https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_maddubs_epi16&expand=3284,3530
1747  *
1748  * vpmaddubsw operates on a vector of activations and a vector of
1749  * weights. If these vectors are
1750  *
1751  *    A (uint8_t) = a0, a1, a2, a3 ...
1752  *
1753  * and
1754  *
1755  *    B (int8_t)  = b0, b1, b2, b3 ...
1756  *
1757  * the result of this instruction is an int16_t vector with values
1758  *
1759  *    C (int16_t) = a0*b0 + a1*b1, a2*b2 + a3*b3 ...
1760  *
1761  * For large values of A and/or B the result (a0*b0 + a1*b1) might not fit into
1762  * an int16_t. So the instruction saturates the result to the maximum (or minimum)
1763  * possible value of an int16_t. Such behavior is expected by the
1764  * implementation below.
1765  *
1766  * For example, with a0 = 255, a1 = 255, b0 = 127 and b1 = 127, the exact result
1767  * 64770 overflows the int16_t range [-32768, 32767], so the returned result
1768  * is 32767.
1769  *
1770  */
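// The saturation described above can be reproduced in plain scalar C++, e.g.:
//   int32_t wide = 255 * 127 + 255 * 127;                 // 64770, exact result
//   int16_t sat  = static_cast<int16_t>(
//       std::clamp<int32_t>(wide, INT16_MIN, INT16_MAX)); // saturates to 32767
// which matches what vpmaddubsw produces for that lane.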
1771 template <int kSpatialDim, bool kReluFused>
1772 class QConvInt8 final {
1773  public:
1774   static Tensor run(
1775       Tensor act,
1776       const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& packed_weight,
1777       double output_scale,
1778       int64_t output_zero_point) {
1779     if (kReluFused) {
1780       return packed_weight->apply_relu(act, output_scale, output_zero_point);
1781     } else {
1782       return packed_weight->apply(act, output_scale, output_zero_point);
1783     }
1784   }
1785 };
1786 
1787 template <int kSpatialDim, bool kReluFused>
1788 class QConvAddInt8 final {
1789  public:
1790   static Tensor run(
1791       Tensor act,
1792       Tensor accum,
1793       const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& packed_weight,
1794       double output_scale,
1795       int64_t output_zero_point) {
1796 #if AT_MKLDNN_ENABLED() || !defined(STRIP_ERROR_MESSAGES)
1797     auto& ctx = at::globalContext();
1798 #endif
1799 #if AT_MKLDNN_ENABLED()
1800     if (ctx.qEngine() == at::QEngine::ONEDNN) {
1801       if (kReluFused) {
1802         return dynamic_cast<PackedConvWeightsOnednn<kSpatialDim>*>(packed_weight.get())->apply_add_relu(
1803           act, accum, output_scale, output_zero_point);
1804       } else {
1805         return dynamic_cast<PackedConvWeightsOnednn<kSpatialDim>*>(packed_weight.get())->apply_add(
1806           act, accum, output_scale, output_zero_point);
1807       }
1808     }
1809 #endif
1810     TORCH_CHECK(
1811     false,
1812     "Didn't find engine for operation quantized::conv2d_add.",
1813     toString(ctx.qEngine()));
1814   }
1815 };
1816 
1817 template <bool kReluFused>
1818 class QConv1dInt8 final {
1819  public:
1820   static Tensor run(
1821       Tensor act,
1822       const c10::intrusive_ptr<ConvPackedParamsBase<2>>& packed_weight,
1823       double output_scale,
1824       int64_t output_zero_point) {
1825     at::Tensor output;
1826     // N, C, L -> N, C, 1, L
1827     act = act.unsqueeze(quant_utils::kConv1dSqueezeDim + 2);
1828     if (kReluFused) {
1829       output = packed_weight->apply_relu(act, output_scale, output_zero_point);
1830     } else {
1831       output = packed_weight->apply(act, output_scale, output_zero_point);
1832     }
1833     // N, C, 1, L -> N, C, L
1834     return output.squeeze_(quant_utils::kConv1dSqueezeDim + 2);
1835   }
1836 };
1837 
1838 // kernel for maintaining backward compatibility
1839 template <int kSpatialDim, bool kReluFused>
1840 class QConvInt8ForBC final {
1841  public:
1842   static Tensor run(
1843       Tensor act,
1844       const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& packed_weight,
1845       torch::List<int64_t> /*stride*/,
1846       torch::List<int64_t> /*padding*/,
1847       torch::List<int64_t> /*dilation*/,
1848       int64_t /*groups*/,
1849       double output_scale,
1850       int64_t output_zero_point) {
1851     if (kReluFused) {
1852       TORCH_WARN_ONCE(
1853           "Arguments [stride, padding, dilation, groups] in ops.quantized.conv" +
1854               std::to_string(kSpatialDim),
1855           "d_relu have been removed; please update your model to remove these arguments.");
1856       return packed_weight->apply_relu(act, output_scale, output_zero_point);
1857     } else {
1858       TORCH_WARN_ONCE(
1859           "Arguments [stride, padding, dilation, groups] in ops.quantized.conv",
1860           std::to_string(kSpatialDim),
1861           "d have been removed; please update your model to remove these arguments.");
1862       return packed_weight->apply(act, output_scale, output_zero_point);
1863     }
1864   }
1865 };
1866 
1867 class QConvoneDNN final {
1868  public:
1869   static at::Tensor run_pointwise(
1870       at::Tensor act, // contains quantized values but not QTensor
1871       double act_scale,
1872       int64_t act_zero_point,
1873       at::Tensor weight, // contains quantized values but not QTensor
1874       at::Tensor weight_scales,
1875       at::Tensor weight_zero_points,
1876       std::optional<at::Tensor> bias,
1877       torch::List<int64_t> stride,
1878       torch::List<int64_t> padding,
1879       torch::List<int64_t> dilation,
1880       int64_t groups,
1881       double output_scale,
1882       int64_t output_zero_point,
1883       std::optional<c10::ScalarType> output_dtype,
1884       c10::string_view attr,
1885       torch::List<std::optional<at::Scalar>> scalars,
1886       std::optional<c10::string_view> algorithm) {
1887 #if AT_MKLDNN_ENABLED()
1888     if (act.dim() == 3 || act.dim() == 5) {
1889       // Conv1D/3D post op check
1890       TORCH_CHECK(
1891         attr == "none",
1892         "quantized pointwise conv",
1893         act.dim()-2,
1894         "d doesn't support unary_post_op fusion. Got unary_post_op: ",
1895         attr,
1896         ".")
1897     } else {
1898       // Conv2D post op check
1899       TORCH_CHECK(
1900         attr == "none" || attr == "relu" || attr == "hardtanh" || attr == "hardswish" || attr == "swish",
1901         "none post_op or post_op relu/hardtanh/hardswish/swish is supported for quantized pointwise conv2d. Got unary_post_op: ",
1902         attr,
1903         ".")
1904     }
1905     return _quantized_convolution_onednn(
1906         act, act_scale, act_zero_point,
1907         weight, weight_scales, weight_zero_points,
1908         bias, stride, padding, dilation, /*transposed*/false,
1909         groups, output_scale, output_zero_point,
1910         /*accum*/std::nullopt, /*accum_scale*/0.0, /*accum_zero_point*/0,
1911         /*output_dtype*/output_dtype, /*binary_attr*/std::nullopt, /*binary_alpha*/std::nullopt,
1912         /*unary_attr*/attr, /*unary_scalars*/scalars, /*unary_algorithm*/algorithm
1913     );
1914 #else
1915     TORCH_CHECK(false, "Unimplemented as onednn is not available.")
1916 #endif
1917   }
1918   static at::Tensor run_pointwise_binary(
1919       at::Tensor act, // contains quantized values but not QTensor
1920       double act_scale,
1921       int64_t act_zero_point,
1922       at::Tensor accum, // contains quantized values but not QTensor
1923       double accum_scale,
1924       int64_t accum_zero_point,
1925       at::Tensor weight, // contains quantized values but not QTensor
1926       at::Tensor weight_scales,
1927       at::Tensor weight_zero_points,
1928       std::optional<at::Tensor> bias,
1929       torch::List<int64_t> stride,
1930       torch::List<int64_t> padding,
1931       torch::List<int64_t> dilation,
1932       int64_t groups,
1933       double output_scale,
1934       int64_t output_zero_point,
1935       std::optional<c10::ScalarType> output_dtype,
1936       c10::string_view binary_attr,
1937       std::optional<at::Scalar> alpha,
1938       std::optional<c10::string_view> unary_attr,
1939       torch::List<std::optional<at::Scalar>> unary_scalars,
1940       std::optional<c10::string_view> unary_algorithm) {
1941 #if AT_MKLDNN_ENABLED()
1942     // Conv2D post op check
1943     TORCH_CHECK(
1944       act.dim() == 4 && binary_attr == "sum" && (
1945         !unary_attr.has_value() ||
1946         (unary_attr.has_value() &&
1947           (
1948             unary_attr.value() == "none" || unary_attr.value() == "relu"
1949           )
1950         )
1951       ),
1952       "post_op sum or post_op sum_relu is supported for quantized pointwise conv2d. Got binary_post_op: ",
1953       binary_attr,
1954       " unary_post_op: ",
1955       unary_attr.has_value() ? unary_attr.value() : "none",
1956       ".")
1957     return _quantized_convolution_onednn(
1958         act, act_scale, act_zero_point,
1959         weight, weight_scales, weight_zero_points,
1960         bias, stride, padding, dilation, /*transposed*/false,
1961         groups, output_scale, output_zero_point,
1962         accum, accum_scale, accum_zero_point,
1963         /*output_dtype*/output_dtype, binary_attr, alpha,
1964         unary_attr, unary_scalars, unary_algorithm
1965     );
1966 #else
1967     TORCH_CHECK(false, "Unimplemented as onednn is not available.")
1968 #endif
1969   }
1970 };
1971 
1972 TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
1973   m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d"),          QConv1dInt8<false>::run);
1974   m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d_relu"),     QConv1dInt8<true>::run);
1975   m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d.new"),      QConvInt8<2, false>::run);
1976   m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_relu.new"), QConvInt8<2, true>::run);
1977   m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_add"),      QConvAddInt8<2, false>::run);
1978   m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_add_relu"), QConvAddInt8<2, true>::run);
1979   m.impl(TORCH_SELECTIVE_NAME("quantized::conv3d.new"),      QConvInt8<3, false>::run);
1980   m.impl(TORCH_SELECTIVE_NAME("quantized::conv3d_relu.new"), QConvInt8<3, true>::run);
1981   // for backward compatibility
1982   m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d"), QConvInt8ForBC<2, false>::run);
1983   m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_relu"), QConvInt8ForBC<2, true>::run);
1984   m.impl(TORCH_SELECTIVE_NAME("quantized::conv3d"), QConvInt8ForBC<3, false>::run);
1985   m.impl(TORCH_SELECTIVE_NAME("quantized::conv3d_relu"), QConvInt8ForBC<3, true>::run);
1986 
1987   // transpose
1988   m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose1d"),  QConv1dInt8<false>::run);
1989   m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d"),  QConvInt8<2, false>::run);
1990   m.impl(
1991       TORCH_SELECTIVE_NAME("quantized::conv_transpose3d"),
1992       QConvInt8<3, false>::run);
1993 }
1994 
1995 TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) {
1996   m.impl(TORCH_SELECTIVE_NAME("_quantized::conv2d"),      QConvInt8<2, false>::run);
1997   m.impl(TORCH_SELECTIVE_NAME("_quantized::conv2d_relu"), QConvInt8<2, true>::run);
1998 
1999   // transpose
2000   m.impl(TORCH_SELECTIVE_NAME("_quantized::conv_transpose1d"),  QConv1dInt8<false>::run);
2001   m.impl(TORCH_SELECTIVE_NAME("_quantized::conv_transpose2d"),  QConvInt8<2, false>::run);
2002 }
2003 
2004 TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) {
2005   // Conv1D/2D/3D with unary postop
2006   m.impl(TORCH_SELECTIVE_NAME("onednn::qconv1d_pointwise"), QConvoneDNN::run_pointwise);
2007   m.impl(TORCH_SELECTIVE_NAME("onednn::qconv2d_pointwise"), QConvoneDNN::run_pointwise);
2008   m.impl(TORCH_SELECTIVE_NAME("onednn::qconv3d_pointwise"), QConvoneDNN::run_pointwise);
2009 
2010   // Conv2D with binary postop
2011   m.impl(TORCH_SELECTIVE_NAME("onednn::qconv2d_pointwise.binary"), QConvoneDNN::run_pointwise_binary);
2012 }
2013 
2014 } // namespace
2015 } // namespace at::native
2016