// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qlinear.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Context.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/TensorOperators.h>
6 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
7 #include <ATen/native/quantized/PackedParams.h>
8 #include <ATen/native/quantized/cpu/QnnpackUtils.h>
9 #include <ATen/native/quantized/cpu/XnnpackUtils.h>
10 #include <ATen/native/quantized/cpu/OnednnUtils.h>
11 #include <ATen/native/quantized/cpu/QuantUtils.h>
12 #include <ATen/native/mkldnn/MKLDNNCommon.h>
13 #include <caffe2/utils/threadpool/pthreadpool-cpp.h>
14 #include <torch/library.h>
15 
16 #ifndef AT_PER_OPERATOR_HEADERS
17 #include <ATen/Functions.h>
18 #include <ATen/NativeFunctions.h>
19 #else
20 #include <ATen/ops/_empty_affine_quantized.h>         // for _empty_affine_q...
21 #include <ATen/ops/_empty_affine_quantized_native.h>  // for empty_affine_qu...
22 #include <ATen/ops/empty.h>                           // for empty
23 #include <ATen/ops/quantize_per_channel_native.h>     // for quantize_per_ch...
24 #include <ATen/ops/quantize_per_tensor_native.h>      // for quantize_per_te...
25 #include <ATen/ops/zeros.h>
26 #endif
27 
28 #include <c10/util/irange.h>
29 
30 #include <algorithm>
31 #include <string>
32 
33 int register_linear_params();
34 
35 #ifdef USE_FBGEMM
36 template <bool ReluFused>
37 at::Tensor& PackedLinearWeight::apply_impl(
38     const at::Tensor& input,
39     double output_scale,
40     int64_t output_zero_point,
41     at::Tensor& output) {
42   // uint8 * int8 -> uint8 (no quantization/dequantization)
43 
44   // We make a strong guarantee that models using these operators will have
45   // the same numerics across different machines. Therefore, we do not provide
46   // a fallback path and rather fail loudly if we cannot run FBGEMM.
47   TORCH_CHECK(
48       fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
49   TORCH_CHECK(input.scalar_type() == c10::kQUInt8,
50                 "Expected input data type ",
51                 toString(c10::kQUInt8),
52                 " but got ",
53                 toString(input.scalar_type()));
54 
55   // TODO: contiguous is called for further jit optimizations.
56   auto input_contig = input.expect_contiguous();
57   const auto* input_ptr =
58       reinterpret_cast<uint8_t*>(input_contig->data_ptr<c10::quint8>());
59 
60   TORCH_CHECK(
61       input.dim() >= 2,
62       "The dimension of input tensor should be larger than or equal to 2");
63   // C(output) = A(input) x B(weight), where C, A, B are M x N, M x K, K x N
64   // matrices, respectively.
65   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
66   int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
67 
68   auto packB = w.get();
69 
70   int64_t N = static_cast<int64_t>(packB->numCols());
71   int64_t K = input.sizes()[input.dim() - 1];
72   TORCH_CHECK(
73       K == static_cast<int64_t>(packB->numRows()),
74       "The number of rows in the packB should be equal to K: " +
75           std::to_string(K));
76 
77   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
78   float input_scale_float = input.q_scale();
79   int32_t input_zero_point_int32 = input.q_zero_point();
80 
81   std::vector<float> output_multiplier_float(1, 0.0);
82   std::vector<float> act_times_w_scale(1, 0.0);
83   TORCH_CHECK(
84       w_scale.size() == w_zp.size(),
85       "Weight scales and zero points vectors should have the same size.");
86   if (q_scheme == c10::kPerTensorAffine) {
87     // Process the per tensor quantization.
88     act_times_w_scale[0] = (input_scale_float * w_scale[0]);
89     output_multiplier_float[0] =
90         act_times_w_scale[0] / static_cast<float>(output_scale);
91   } else if (q_scheme == c10::kPerChannelAffine) {
92     // Process the per channel quantization.
93     output_multiplier_float.resize(N, 0.0);
94     act_times_w_scale.resize(N, 1.0f);
95     for (const auto i : c10::irange(N)) {
96       act_times_w_scale[i] = (input_scale_float * w_scale[i]);
97       output_multiplier_float[i] =
98           act_times_w_scale[i] / static_cast<float>(output_scale);
99     }
100   }
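  // Roughly speaking, the multipliers computed above implement standard affine
  // requantization: given the int32 accumulator
  //   acc = sum_k (a_k - input_zp) * (w_k - weight_zp)
  // the requantized output is
  //   q_out = clamp(round(acc * input_scale * w_scale / output_scale) + output_zp, 0, 255).
  // For example, input_scale = 0.1, w_scale = 0.05 and output_scale = 0.25 give
  // output_multiplier_float = (0.1 * 0.05) / 0.25 = 0.02. In the per-channel
  // branch there is one multiplier per output column (N of them), consumed by
  // ReQuantizeOutput below.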
101   int32_t output_zero_point_int32 = static_cast<int32_t>(output_zero_point);
102 
103   const float* bias_ptr = nullptr;
104   c10::MaybeOwned<at::Tensor> bias_contig;
105   if (this->bias_.has_value()) {
106     auto& bias = this->bias_.value();
107     bias_contig = bias.expect_contiguous();
108     TORCH_CHECK(bias_contig->dim() == 1, "bias should be a vector (1D Tensor)");
109     TORCH_CHECK(
110         bias_contig->sizes()[0] == N, "bias should have N elements: " + std::to_string(N));
111     bias_ptr = bias_contig->data_ptr<float>();
112   }
113 
114   // The resulting matrix here is 2-D, let's view it with the original
115   // left hand dimensions of the input. Here are two examples:
116   // 1. If the input tensor is {M, K}, the output tensor is {M, N}.
117   // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
118   at::DimVector out_sizes(input.sizes());
119   out_sizes.back() = N;
120   // Resize output Tensor
121   output.resize_(out_sizes);
122 
123   // Allocate a buffer for fbgemmPacked to use
124   auto buffer = at::empty(out_sizes, output.options().dtype(at::kInt));
125 
126   auto output_data = reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>());
127 
128   int num_tasks = at::get_num_threads();
129   at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
130     for (const auto task_id : c10::irange(begin, end)) {
131       // This operation does the following:
132       // 1) Creates a "row buffer" vector with offset values that must be
133       //    added to the integer matrix multiplication operation to ensure
134       //    correctness. This "row buffer" is also called the row offset, and
135       //    it is needed when we use affine quantization for weights.
136       // 2) Packs the resulting quantized matrix into vector-register and
137       //    cache friendly tiles.
138       //
139       //  Note this is not executed eagerly, but rather within the
140       //  fbgemmPacked call below.
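      // Roughly, the correction comes from expanding the affine product:
      //   sum_k (A_ik - a_zp) * (B_kj - b_zp)
      //     = sum_k A_ik * B_kj - a_zp * colsum_j(B) - b_zp * rowsum_i(A)
      //       + K * a_zp * b_zp
      // The row-offset buffer holds rowsum_i(A); the prepacked weight already
      // carries the column sums (col_offsets), and ReQuantizeOutput applies
      // both corrections before requantizing.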
141       fbgemm::PackAWithRowOffset<uint8_t> packA(
142           /*trans=*/fbgemm::matrix_op_t::NoTranspose,
143           /*nRow=*/M,
144           /*nCol=*/K,
145           /*smat=*/input_ptr,
146           /*ld=*/K,
147           /*pmat=*/nullptr); // Currently, packA manages ownership of `pmat`.
148                              // TODO: Consider a way to pre-allocate and reuse
149                              // pmat buffer.
150 
151       // ReQuantizeOutput requires pointers to the zero point values,
152       // since in the case of rowwise quantization these will be arrays rather
153       // than scalars. But in this case, we're doing whole-tensor quantization
154       // so we just pass a pointer to the scale values (and internally
155       // ReQuantizeOutput won't index past 0).
156 
157       // This is the end of the pipeline, pass the resulting matrix through.
158       fbgemm::DoNothing<> doNothingObj{};
159 
160       if (q_scheme == c10::kPerTensorAffine) {
161         // Process the per tensor quantization.
162         //
163         // After the uint8 * int8 matrix multiplication is performed, this
164         // operation does:
165         //  1) Add in row and column offsets to the rows and columns,
166         //  respectively.
167         //  2) Add in the bias term.
168         fbgemm::ReQuantizeOutput<
169             ReluFused,
170             fbgemm::QuantizationGranularity::TENSOR,
171             float>
172             outputProcObj(
173                 doNothingObj,
174                 output_multiplier_float.data(),
175                 output_zero_point_int32,
176                 input_zero_point_int32,
177                 w_zp.data(),
178                 packA.getRowOffsetBuffer(),
179                 col_offsets.data(),
180                 bias_ptr,
181                 N, /* nCol */
182                 1 /* groups */,
183                 act_times_w_scale.data());
184 
185         // Do the GEMM
186         fbgemm::fbgemmPacked(
187             /*packA=*/packA,
188             /*packB=*/*packB,
189             /*C=*/output_data,
190             /*C_buffer=*/buffer.data_ptr<int32_t>(),
191             /*ldc=*/N,
192             /*outProcess=*/outputProcObj,
193             /*thread_id=*/task_id,
194             /*num_threads=*/num_tasks);
195       } else if (q_scheme == c10::kPerChannelAffine) {
196         // Process the per channel quantization.
197         //
198         // After the uint8 * int8 matrix multiplication is performed, this
199         // operation does:
200         //  1) Add in row and column offsets to the rows and columns,
201         //  respectively.
202         //  2) Add in the bias term.
203         fbgemm::ReQuantizeOutput<
204             ReluFused,
205             fbgemm::QuantizationGranularity::OUT_CHANNEL,
206             float>
207             outputProcObj(
208                 doNothingObj,
209                 output_multiplier_float.data(),
210                 output_zero_point_int32,
211                 input_zero_point_int32,
212                 w_zp.data(),
213                 packA.getRowOffsetBuffer(),
214                 col_offsets.data(),
215                 bias_ptr,
216                 // NOLINTNEXTLINE(bugprone-argument-comment)
217                 N, /*nCol=*/
218                 1, /* groups*/
219                 act_times_w_scale.data());
220 
221         // Do the GEMM
222         fbgemm::fbgemmPacked(
223             /*packA=*/packA,
224             /*packB=*/*packB,
225             /*C=*/output_data,
226             /*C_buffer=*/buffer.data_ptr<int32_t>(),
227             /*ldc=*/N,
228             /*outProcess=*/outputProcObj,
229             /*thread_id=*/task_id,
230             /*num_threads=*/num_tasks);
231       }
232     }
233   });
234 
235   return output;
236 }
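// For reference, this is the kernel behind the quantized::linear op registered
// at the bottom of this file. A typical call from Python looks roughly like
// the sketch below (w_packed is assumed to come from quantized::linear_prepack;
// the scales/zero points are placeholders):
//
//   qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)
//   qy = torch.ops.quantized.linear(qx, w_packed, 0.2, 64)
//
// apply()/apply_relu() below allocate the quantized output tensor and forward
// to apply_impl().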
237 
238 at::Tensor PackedLinearWeight::apply(
239     at::Tensor input,
240     double output_scale,
241     int64_t output_zero_point) {
242   // Allocate output Tensor
243   auto output = at::_empty_affine_quantized(
244       {0},
245       at::device(c10::kCPU).dtype(c10::kQUInt8),
246       output_scale,
247       output_zero_point);
248   apply_impl<false>(input, output_scale, output_zero_point, output);
249   return output;
250 }
251 
252 at::Tensor PackedLinearWeight::apply_relu(
253     at::Tensor input,
254     double output_scale,
255     int64_t output_zero_point) {
256   auto output = at::_empty_affine_quantized(
257       {0},
258       at::device(c10::kCPU).dtype(c10::kQUInt8),
259       output_scale,
260       output_zero_point);
261   apply_impl<true>(input, output_scale, output_zero_point, output);
262   return output;
263 }
264 
265 at::Tensor& PackedLinearWeight::apply_out(
266     const at::Tensor& input,
267     double output_scale,
268     int64_t output_zero_point,
269     at::Tensor& output) {
270   TORCH_CHECK(
271       (output.device() == c10::kCPU) && (output.dtype() == c10::kQUInt8) &&
272       (output.q_scale() == output_scale) &&
273       (output.q_zero_point() == output_zero_point));
274   return apply_impl<false>(input, output_scale, output_zero_point, output);
275 }
276 
277 at::Tensor& PackedLinearWeight::apply_relu_out(
278     const at::Tensor& input,
279     double output_scale,
280     int64_t output_zero_point,
281     at::Tensor& output) {
282   TORCH_CHECK(
283       (output.device() == c10::kCPU) && (output.dtype() == c10::kQUInt8) &&
284       (output.q_scale() == output_scale) &&
285       (output.q_zero_point() == output_zero_point));
286   return apply_impl<true>(input, output_scale, output_zero_point, output);
287 }
288 
289 at::Tensor PackedLinearWeight::apply_with_input_q_dq_qweight_dq_output_fp32(
290   at::Tensor input,
291   double input_scale,
292   int64_t input_zero_point) {
293   TORCH_CHECK(!input.is_quantized(), "Input tensor for apply_with_input_q_dq_qweight_dq_output_fp32 is quantized; "
294   "Expected input tensor in PackedLinearWeight::apply_with_input_q_dq_qweight_dq_output_fp32 to be full precision.");
295 
296   return apply_with_input_q_dq_qweight_dq_output_fp32_impl<false>(input, input_scale, input_zero_point);
297 }
298 
299 at::Tensor PackedLinearWeight::apply_with_input_q_dq_qweight_dq_relu_output_fp32(
300   at::Tensor input,
301   double input_scale,
302   int64_t input_zero_point) {
303   TORCH_CHECK(!input.is_quantized(), "Input tensor for apply_with_input_q_dq_qweight_dq_relu_output_fp32 is quantized; "
304   "Expected input tensor in PackedLinearWeight::apply_with_input_q_dq_qweight_dq_relu_output_fp32 to be full precision.");
305 
306   return apply_with_input_q_dq_qweight_dq_output_fp32_impl<true>(input, input_scale, input_zero_point);
307 }
308 
309 
310 template <bool ReluFused>
311 at::Tensor PackedLinearWeight::apply_with_input_q_dq_qweight_dq_output_fp32_impl(
312     const at::Tensor& input,
313     double input_scale,
314     int64_t input_zero_point) {
315   TORCH_CHECK(
316       fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
317 
318   auto input_contig = input.expect_contiguous();
319   const auto* input_ptr = input_contig->const_data_ptr<float>();
320 
321   TORCH_CHECK(
322       input.dim() >= 2,
323       "The dimension of input tensor should be larger than or equal to 2");
324   int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
325 
326   auto packB = w.get();
327 
328   int64_t N = static_cast<int64_t>(packB->numCols());
329   int64_t K = input.sizes()[input.dim() - 1];
330   TORCH_CHECK(
331       K == static_cast<int64_t>(packB->numRows()),
332       "The number of rows in the packB should be equal to K: " +
333           std::to_string(K));
334 
335   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
336   float input_scale_float = input_scale;
337   int32_t input_zero_point_int32 = input_zero_point;
338 
339   TORCH_CHECK(
340       w_scale.size() == w_zp.size(),
341       "Weight scales and zero points vectors should have the same size.");
342 
343   const float* bias_ptr = nullptr;
344   c10::MaybeOwned<at::Tensor> bias_contig;
345   if (this->bias_.has_value()) {
346     auto& bias = this->bias_.value();
347     bias_contig = bias.expect_contiguous();
348     TORCH_CHECK(bias_contig->dim() == 1, "bias should be a vector (1D Tensor)");
349     TORCH_CHECK(
350         bias_contig->sizes()[0] == N, "bias should have N elements: " + std::to_string(N));
351     bias_ptr = bias_contig->data_ptr<float>();
352   }
353 
354   std::vector<int64_t> out_sizes = input.sizes().vec();
355   out_sizes.back() = N;
356   // Allocate output Tensor and a buffer for fbgemmPacked to use
357   auto output = at::empty(out_sizes, input.options().dtype(at::kFloat));
358   auto buffer = at::empty_like(
359       output,
360       output.options().dtype(at::kInt),
361       LEGACY_CONTIGUOUS_MEMORY_FORMAT);
362 
363   auto output_data = output.data_ptr<float>();
364 
365   int num_tasks = at::get_num_threads();
366   at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
367     fbgemm::PackAWithQuantRowOffset<uint8_t> packA(
368         /*trans=*/fbgemm::matrix_op_t::NoTranspose,
369         /*nRow=*/M,
370         /*nCol=*/K,
371         /*smat=*/input_ptr,
372         /*ld=*/K,
373         /*pmat=*/nullptr,
374         /*scale=*/input_scale_float,
375         /*zero_pt=*/input_zero_point_int32);
376 
377     fbgemm::DoNothing<float, float> doNothingObj{};
378     for (const auto task_id : c10::irange(begin, end)) {
379       if (q_scheme == c10::kPerTensorAffine) {
380         // Process the per tensor quantization.
381         //
382         // After the uint8 * int8 matrix multiplication is performed, this
383         // operation does:
384         //  1) Add in row and column offsets to the rows and columns,
385         //  respectively.
386         //  2) Add in the bias term.
387         fbgemm::ReQuantizeForFloat<ReluFused>
388             outputProcObj(
389                 doNothingObj,
390                 input_scale_float,
391                 w_scale.data(),
392                 input_zero_point_int32,
393                 w_zp.data(),
394                 packA.getRowOffsetBuffer(),
395                 col_offsets.data(),
396                 bias_ptr,
397                 N /* nCol */);
398 
399         // Do the GEMM
400         fbgemm::fbgemmPacked(
401             /*packA=*/packA,
402             /*packB=*/*packB,
403             /*C=*/output_data,
404             /*C_buffer=*/buffer.data_ptr<int32_t>(),
405             /*ldc=*/N,
406             /*outProcess=*/outputProcObj,
407             /*thread_id=*/task_id,
408             /*num_threads=*/num_tasks);
409       } else if (q_scheme == c10::kPerChannelAffine) {
410         // Process the per channel quantization.
411         //
412         // After the uint8 * int8 matrix multiplication is performed, this
413         // operation does:
414         //  1) Add in row and column offsets to the rows and columns,
415         //  respectively.
416         //  2) Add in the bias term.
417         fbgemm::ReQuantizeForFloat<
418             ReluFused,
419             fbgemm::QuantizationGranularity::OUT_CHANNEL>
420             outputProcObj(
421                 doNothingObj,
422                 input_scale_float,
423                 w_scale.data(),
424                 input_zero_point_int32,
425                 w_zp.data(),
426                 packA.getRowOffsetBuffer(),
427                 col_offsets.data(),
428                 bias_ptr,
429                 N /* nCol */);
430 
431         // Do the GEMM
432         fbgemm::fbgemmPacked(
433             /*packA=*/packA,
434             /*packB=*/*packB,
435             /*C=*/output_data,
436             /*C_buffer=*/buffer.data_ptr<int32_t>(),
437             /*ldc=*/N,
438             /*outProcess=*/outputProcObj,
439             /*thread_id=*/task_id,
440             /*num_threads=*/num_tasks);
441       }
442     }
443   });
444   return output;
445 }
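// In the *_fp32 path above, the activation stays fp32 at the boundaries:
// PackAWithQuantRowOffset quantizes it on the fly using (input_scale,
// input_zero_point), the GEMM runs as uint8 * int8 -> int32, and
// ReQuantizeForFloat dequantizes the accumulator back to fp32 (adding the bias
// and, when ReluFused, clamping at zero). The result therefore roughly matches
//   linear(dequant(quant(x)), dequant(qweight), bias)
// which is what the op name input_q_dq_qweight_dq_output_fp32 describes.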
446 
447 #endif // USE_FBGEMM
448 
449 #ifdef USE_PYTORCH_QNNPACK
450 
451 #ifdef USE_XNNPACK
452 // TODO: add per_channel support in the future when xnnp supports it
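// The XNNPACK path below follows the usual XNNPACK operator lifecycle: a
// fully-connected operator is created once and cached on this object (it is
// only re-created when the input scale changes), then reshaped/set up for the
// current batch size and run on the caffe2 pthreadpool.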
453 template <typename scalar_t, bool kReluFused>
454 at::Tensor PackedLinearWeightsQnnp::apply_impl_xnnp(
455     const at::Tensor& input,
456     double output_scale,
457     int64_t output_zero_point) {
458   using underlying_t = typename scalar_t::underlying;
459 
460   std::lock_guard<std::mutex> lock(qnnp_mutex_);
461 
462   const std::string func_name = kReluFused ? "quantized::linear_relu (xnnpack)"
463                                            : "quantized::linear (xnnpack)";
464   TORCH_CHECK(
465       input.dim() >= 2, func_name, ": Input tensor rank should be >= 2.");
466   TORCH_CHECK(
467       !per_channel(),
468       func_name,
469       ": xnnpack does not currently have per_channel support.");
470 
471   const auto input_contig = input.contiguous();
472   const auto input_scale = input_contig.q_scale();
473 
474   const size_t rows_w = bias_.size(0);
475   const size_t cols_w = input_contig.size(input_contig.dim() - 1);
476 
477   auto status = xnn_status_invalid_state;
478 
479   // Create an operator iff not already created
480   if (!xnnp_linear_op ||
481       (!this->input_scale.has_value() ||
482        this->input_scale.value() != input_scale)) {
483     // Update the input scale so we may cache the op
484     this->input_scale = input_scale;
485 
486     xnn_operator_t xnnp_op = nullptr;
487 
488     const float* weight_scales_data = w_scales.const_data_ptr<float>();
489 
490     // prepare weights
491     underlying_t w_zp = static_cast<underlying_t>(
492         orig_weight.q_zero_point() +
493         (std::is_same<underlying_t, uint8_t>::value ? 128 : 0));
494 
495     at::Tensor xnnp_weight = at::_empty_affine_quantized(
496         orig_weight.sizes(),
497         c10::CppTypeToScalarType<scalar_t>::value,
498         weight_scales_data[0],
499         w_zp);
500 
501     // copy from the original weight and take care of dtype change if necessary
502     at::native::xnnp_utils::q8_copy_int8_weight_and_add_offset<scalar_t>(
503         orig_weight, xnnp_weight);
504 
505     // Original bias was float, so we requantize it here.
506     at::Tensor qbias = quant_utils::QuantizeBias(false, bias_, orig_weight, input_scale);
507 
508     // output limits
509     auto output_min = kReluFused
510         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
511         ? activationLimits<underlying_t>(output_scale, output_zero_point, Activation::RELU).first
512         : std::numeric_limits<underlying_t>::min();
513     auto output_max = kReluFused
514         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
515         ? activationLimits<underlying_t>(output_scale, output_zero_point, Activation::RELU).second
516         : std::numeric_limits<underlying_t>::max();
517 
518     // Create an operator
519     status = at::native::xnnp_utils::xnnp_create_fully_connected_nc(
520         cols_w, /* input_channels */
521         rows_w, /* output_channels */
522         cols_w, /* input_stride */
523         rows_w, /* output_stride */
524         input_contig.q_zero_point(),
525         input_contig.q_scale(),
526         w_zp,
527         weight_scales_data[0],
528         reinterpret_cast<const underlying_t*>(
529             xnnp_weight.template data_ptr<scalar_t>()),
530         reinterpret_cast<int32_t*>(qbias.data_ptr<c10::qint32>()),
531         output_zero_point,
532         output_scale,
533         output_min,
534         output_max,
535         0, /* flags */
536         &xnnp_op);
537     xnnp_linear_op = xnnpack_operator(xnnp_op);
538 
539     TORCH_CHECK(
540         status == xnn_status_success,
541         func_name,
542         ": xnn create operator failed(",
543         status,
544         ")");
545   }
546 
547   /*
548    * Allocate output Tensor and a buffer for XNNPACK to use
549    * The resulting matrix here is 2-D, let's view it with the original
550    * left hand dimensions of the input. Here are two examples:
551    * 1. If the input tensor is {M, K}, the output tensor is {M, N}.
552    * 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
553    */
554   std::vector<int64_t> out_sizes = input.sizes().vec();
555   out_sizes.back() = static_cast<int64_t>(rows_w);
556   at::Tensor output = at::native::empty_affine_quantized(
557       out_sizes,
558       c10::CppTypeToScalarType<scalar_t>::value,
559       std::nullopt /* layout */,
560       c10::kCPU,
561       std::nullopt /* pin_memory */,
562       output_scale,
563       output_zero_point,
564       input.suggest_memory_format());
565 
566   // calculate batch_size
567   size_t rows_input = 1;
568   for (const auto i : c10::irange(input_contig.dim() - 1)) {
569     rows_input *= input_contig.size(i);
570   }
571 
572   // Reshape the operator
573   status = at::native::xnnp_utils::xnnp_reshape_fully_connected_nc(
574       xnnp_linear_op.get(),
575       rows_input, /* batch_size */
576       caffe2::pthreadpool_());
577 
578   // Setup the operator
579   status = at::native::xnnp_utils::xnnp_setup_fully_connected_nc(
580       xnnp_linear_op.get(),
581       reinterpret_cast<const underlying_t*>(
582           input_contig.template data_ptr<scalar_t>()),
583       reinterpret_cast<underlying_t*>(output.template data_ptr<scalar_t>())
584     );
585 
586   TORCH_CHECK(
587       status == xnn_status_success,
588       func_name,
589       ": xnn setup operator failed(",
590       status,
591       ")");
592 
593   // Run the operator
594   status = xnn_run_operator(
595       xnnp_linear_op.get(), // Linear op
596       caffe2::pthreadpool_() // threadpool
597   );
598   TORCH_CHECK(
599       status == xnn_status_success,
600       func_name,
601       ": xnn run operator failed(",
602       status,
603       ")");
604 
605   return output;
606 }
607 #endif // USE_XNNPACK
608 
609 template <bool ReluFused>
610 at::Tensor PackedLinearWeightsQnnp::apply_impl(
611     at::Tensor input,
612     double output_scale,
613     int64_t output_zero_point) {
614   TORCH_CHECK(
615       input.dim() >= 2,
616       "quantized::linear(): Input tensor rank should be >= 2");
617   TORCH_CHECK(input.scalar_type() == c10::kQUInt8,
618                 "quantized::linear (qnnpack): Expected input data type ",
619                 toString(c10::kQUInt8),
620                 " but got ",
621                 toString(input.scalar_type()));
622 
623   auto input_contig = input.contiguous();
624 
625   // Weight packing is not thread safe
626   std::lock_guard<std::mutex> lock(qnnp_mutex_);
627   auto packB = w.get();
628   size_t rows_w = bias_.size(0);
629   size_t cols_w = input_contig.size(input_contig.dim() - 1);
630   auto input_scale = input_contig.q_scale();
631 
632   if (!this->input_scale.has_value() ||
633       this->input_scale.value() != input_scale) {
634     // Get the original weight and adjust it to uint8 from int8
635     auto weight_contig = orig_weight;
636     auto bias_fp32 = bias_;
637     int8_t* w_data = (int8_t*)weight_contig.data_ptr<c10::qint8>();
638 
639     float* weight_scales_data = w_scales.data_ptr<float>();
640     // We calculate requant scale here as the vector holding the requant scale
641     // is owned by this module. The pointer is then passed to qnnpack backend.
642     generate_requantization_scales(
643         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
644         w_scales, input_scale, output_scale, requantization_scales);
645 
646     at::Tensor qnnp_weight = at::_empty_affine_quantized(
647         weight_contig.sizes(),
648         at::device(c10::kCPU).dtype(c10::kQUInt8),
649         weight_scales_data[0],
650         w_zero_points[0]);
651     auto* qnnp_w_data = qnnp_weight.data_ptr<c10::quint8>();
652     auto wt_numel = weight_contig.numel();
653     for (const auto i : c10::irange(wt_numel)) {
654       qnnp_w_data[i] = static_cast<c10::quint8>(w_data[i] + 128);
655     }
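    // Adding 128 reinterprets the int8 weights as uint8 with the zero point
    // shifted by 128 (e.g. int8 value -5 with zero point 0 becomes uint8 123
    // with zero point 128, representing the same real number), which is the
    // layout QNNPACK's uint8 kernels expect; w_zero_points carries the
    // matching offset from prepack time.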
656     // Original bias was float, so we requantize it here.
657     const bool is_per_channel = orig_weight.qscheme() == at::kPerChannelAffine;
658     at::Tensor qbias = quant_utils::QuantizeBias(is_per_channel, bias_fp32, weight_contig, input_scale);
659 
660     // Update the input scale to not pack again.
661     this->input_scale = input_scale;
662     w.reset();
663     w = std::make_unique<qnnpack::PackBMatrix>(
664         cols_w /* input_channels */,
665         rows_w /* output_channels */,
666         w_zero_points.data(),
667         requantization_scales.data(),
668         reinterpret_cast<uint8_t*>(qnnp_w_data),
669         reinterpret_cast<int32_t*>(qbias.data_ptr<c10::qint32>()));
670     packB = w.get();
671     if (at::globalContext().releaseWeightsWhenPrepacking()) {
672       // On mobile, we release the original weight by resetting the intrusive_ptr.
673       // Calling unpack after this will throw an assertion.
674       orig_weight.reset();
675     }
676   }
677 
678   size_t rows_input = 1;
679   size_t cols_input = input_contig.size(input_contig.dim() - 1);
680   for (const auto i : c10::irange(input_contig.dim() -1)) {
681     rows_input *= input_contig.size(i);
682   }
683 
684   TORCH_CHECK(
685       cols_input == cols_w,
686       "quantized::linear(): input size does not match weight dimension 1 size: \
687          got ",
688       cols_input,
689       " but expected ",
690       cols_w);
691 
692   // Allocate output Tensor and a buffer for QNNPACK to use
693   // The resulting matrix here is 2-D, let's view it with the original
694   // left hand dimensions of the input. Here are two examples:
695   // 1. If the input tensor is {M, K}, the output tensor is {M, N}.
696   // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
697   std::vector<int64_t> out_sizes = input.sizes().vec();
698   out_sizes.back() = static_cast<long>(rows_w);
699   at::Tensor output = at::_empty_affine_quantized(
700       out_sizes,
701       input.options(),
702       output_scale,
703       output_zero_point);
704 
705   auto output_min = ReluFused
706       // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
707       ? activationLimits<uint8_t>(output_scale, output_zero_point, Activation::RELU)
708             .first
709       : std::numeric_limits<uint8_t>::min();
710   auto output_max = ReluFused
711       // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
712       ? activationLimits<uint8_t>(output_scale, output_zero_point, Activation::RELU)
713             .second
714       : std::numeric_limits<uint8_t>::max();
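  // With ReluFused, the lower clamp returned by activationLimits is (roughly)
  // the quantized value of 0.0, i.e. the output zero point clamped to the
  // uint8 range, so anything that would dequantize to a negative value is
  // clipped; without fusion the full [0, 255] range is allowed.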
715   TORCH_INTERNAL_ASSERT(packB != nullptr, "Packed Weights are NULL");
716   const pytorch_qnnp_status runStatus = qnnpack::qnnpackLinear(
717       rows_input /* batch_size */,
718       cols_input /* input_channels */,
719       rows_w /* output_channels */,
720       input_contig.q_zero_point(),
721       w_zero_points.data(),
722       requantization_scales.data(),
723       output_zero_point,
724       output_min,
725       output_max,
726       (uint8_t*)input_contig.data_ptr<c10::quint8>(),
727       cols_input /* input_stride */,
728       packB->getPackedWeights(),
729       (uint8_t*)output.data_ptr<c10::quint8>(),
730       rows_w /* output_stride */,
731       // TODO (Ashkan): Disabling temporarily.
732       // Throws a floating point exception with OSS pthreadpool.
733       caffe2::pthreadpool_() /* threadpool */);
734 
735   TORCH_INTERNAL_ASSERT(
736       runStatus == pytorch_qnnp_status_success,
737       "failed to run QNNPACK Linear operator");
738 
739   return output;
740 }
741 
742 #ifdef USE_XNNPACK
743 static bool can_use_xnnp(c10::ScalarType dtype, bool per_channel) {
744   if (!at::native::xnnpack::available()) {
745     return false;
746   }
747 
748   bool supported_dtypes = dtype == c10::kQInt8;
749   bool invalid_config = per_channel; /* xnnp does not currently support
750                                         per-channel fully connected op */
751   if (supported_dtypes && invalid_config) {
752     /* don't want this to fall through to QNNPACK */
753     TORCH_CHECK(
754         false,
755         "quantized::linear (xnnpack): Unsupported config for dtype KQInt8");
756   }
757   return supported_dtypes && !invalid_config;
758 }
759 #endif // USE_XNNPACK
760 
761 at::Tensor PackedLinearWeightsQnnp::apply(
762     at::Tensor input,
763     double output_scale,
764     int64_t output_zero_point) {
765 #ifdef USE_XNNPACK
766   if (can_use_xnnp(input.scalar_type(), per_channel())) {
767     return apply_impl_xnnp<c10::qint8, false>(
768         input, output_scale, output_zero_point);
769   } /* fall through for unsupported types, configs, or shapes */
770 #endif // USE_XNNPACK
771   return apply_impl<false>(std::move(input), output_scale, output_zero_point);
772 }
773 
774 at::Tensor PackedLinearWeightsQnnp::apply_relu(
775     at::Tensor input,
776     double output_scale,
777     int64_t output_zero_point) {
778 #ifdef USE_XNNPACK
779   if (can_use_xnnp(input.scalar_type(), per_channel())) {
780     return apply_impl_xnnp<c10::qint8, true>(
781         input, output_scale, output_zero_point);
782   } /* fall through for unsupported types, configs, or shapes */
783 #endif // USE_XNNPACK
784   return apply_impl<true>(std::move(input), output_scale, output_zero_point);
785 }
786 
787 #endif // USE_PYTORCH_QNNPACK
788 
789 #if AT_MKLDNN_ENABLED()
790 template <PostOps post_op>
791 at::Tensor PackedLinearWeightsOnednn::apply_impl(
792     at::Tensor input,
793     double output_scale,
794     int64_t output_zero_point,
795     torch::List<at::Scalar> post_op_args) {
796   const int64_t dim = input.dim();
797   TORCH_CHECK(
798       dim != 0,
799       "qlinear (ONEDNN): input dim should be at least 1, but got 0");
800   TORCH_CHECK(input.scalar_type() == c10::ScalarType::QUInt8,
801       "qlinear (ONEDNN): data type of input should be QUint8.");
802 
803   auto input_contig = input.expect_contiguous();
804   auto& w = *(weight_.get());
805   auto K = input.size(dim - 1), M = input.numel() / K, N = w.get_dim(1);
806   auto input_dims = {M, K};
807   auto input_data_type = dnnl::memory::data_type::u8;
808   auto input_desc = ideep::tensor::desc(input_dims, input_data_type);
809   ideep::attr_t op_attr = ideep::attr_t();
810   if (post_op == Relu) {
811     op_attr = ideep::attr_t::fuse_relu();
812   } else if (post_op == LeakyRelu) {
813     op_attr = ideep::attr_t::fuse_relu(/*scale=*/1.0f, /*alpha=*/post_op_args.get(0).to<double>());
814   } else if (post_op == Tanh) {
815     op_attr = ideep::attr_t::fuse_tanh();
816   }
817   ideep::tensor x(input_desc, input_contig->data_ptr<c10::quint8>());
818   auto dst_dims = {M, N};
819   double input_scale = input.q_scale();
820   int64_t input_zero_point = input.q_zero_point();
821   const ideep::scale_t& src_scales = ideep::scale_t(1, 1.0/input_scale);
822   const ideep::scale_t& weights_scales = w.get_scale();
823   // Scales of ONEDNN and PyTorch are reciprocal
824   const ideep::scale_t& dst_scales = ideep::scale_t(1, 1.0/output_scale);
825   const ideep::zero_point_t& src_zero_point = ideep::zero_point_t(1, input_zero_point);
826   const ideep::zero_point_t& dst_zero_point = ideep::zero_point_t(1, output_zero_point);
827   // Compute: Use ideep::matmul_forward to support asymmetric quantization
828   // Allocate output Tensor
829   at::Tensor output = at::_empty_affine_quantized(
830       dst_dims,
831       at::device(c10::kCPU).dtype(c10::kQUInt8),
832       output_scale,
833       output_zero_point);
834   if (output.numel() == 0) {
835     return output;
836   }
837   ideep::tensor y({dst_dims, ideep::tensor::data_type::u8,
838                    {output.strides().cbegin(), output.strides().cend()}},
839                   output.data_ptr());
840   bool with_bias = bias_.has_value();
841   if (with_bias) {
842     // Bias might be modified outside (e.g. by quantization bias correction).
843     // If so, update the prepacked bias as well.
844     if (bias_.value().get_data_handle() != orig_bias_.value().data_ptr()) {
845       bias_.value().init(bias_.value().get_desc(), orig_bias_.value().data_ptr());
846     }
847   }
848   const auto& b = with_bias ? bias_.value() : ideep::tensor();
849   // Primitive cache is initialized when called for the first time
850   // and won't be updated afterwards.
851   int num_threads = at::get_num_threads();
852   PrimitiveCacheKey cache_key = std::make_tuple(
853       input_scale, input_zero_point, input_dims, output_scale, output_zero_point, num_threads, /*accum scale*/1.0, /*accum zero point*/0);
854   c10::call_once(*cache_initialized_flag, [&](){
855       LinearParams params;
856       ideep::matmul_forward::prepare</*is_dynamic=*/false>(
857           params, x, w, b, y,
858           src_scales, weights_scales, dst_scales,
859           src_zero_point, dst_zero_point, 1.0f, 1.0f, op_attr);
860       get_cache() = LinearPrimitiveCache(cache_key, params);
861       w = w.reorder_if_differ_in(params.pd.weights_desc());
862   });
863   if (get_cache().hit(cache_key)) {
864     LinearParams& params = get_cache().get_param();
865     ideep::matmul_forward::compute<false, false>(params, x, w, b, y);
866   } else {
867     ideep::matmul_forward::compute(x, w, b, y, src_scales, weights_scales,
868                                    dst_scales, src_zero_point, dst_zero_point,
869                                    1.0f, 1.0f, op_attr);
870   }
871   auto out_sizes = input.sizes().vec();
872   out_sizes.back() = N;
873   if (output.sizes().vec() == out_sizes)
874     return output;
875   return output.reshape(out_sizes);
876 }
877 
878 at::Tensor PackedLinearWeightsOnednn::apply(
879     at::Tensor input,
880     double output_scale,
881     int64_t output_zero_point) {
882   return apply_impl<NoPostOp>(
883       std::move(input), output_scale, output_zero_point);
884 }
885 
886 at::Tensor PackedLinearWeightsOnednn::apply_relu(
887     at::Tensor input,
888     double output_scale,
889     int64_t output_zero_point) {
890   return apply_impl<Relu>(
891       std::move(input), output_scale, output_zero_point);
892 }
893 
894 at::Tensor PackedLinearWeightsOnednn::apply_leaky_relu(
895     at::Tensor input,
896     double output_scale,
897     int64_t output_zero_point,
898     double negative_slope) {
899   torch::List<at::Scalar> post_op_args =
900       {at::Scalar(negative_slope)};
901   return apply_impl<LeakyRelu>(
902       std::move(input), output_scale, output_zero_point, post_op_args);
903 }
904 
905 at::Tensor PackedLinearWeightsOnednn::apply_tanh(
906     at::Tensor input,
907     double output_scale,
908     int64_t output_zero_point) {
909   return apply_impl<Tanh>(
910       std::move(input), output_scale, output_zero_point);
911 }
912 
913 static at::Tensor linear_int8_with_onednn_weight(
914     at::Tensor input, // int8 CPU Tensor, not QTensor
915     double input_scale,
916     int64_t input_zero_point,
917     at::Tensor onednn_weight, // int8 tensor from MkldnnCPU
918     at::Tensor weight_scales,
919     at::Tensor weight_zero_points,
920     std::optional<at::Tensor> bias, // plain tensor
921     double output_scale,
922     int64_t output_zero_point,
923     std::optional<c10::ScalarType> output_dtype,
924     std::optional<at::Tensor> other, // extra input for binary post-op
925     double other_scale,
926     int64_t other_zero_point,
927     const c10::string_view& binary_post_op, // e.g. "none", "sum", "add"
928     double binary_alpha,
929     const c10::string_view& unary_post_op, // e.g. "none", "relu"
930     torch::List<std::optional<at::Scalar>>& unary_post_op_args,
931     c10::string_view& unary_post_op_algorithm) {
932   using ideep::tensor;
933   const int64_t dim = input.dim();
934   TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte,
935       "qlinear with mkldnn tensor: data type of input should be uint8 (unsigned char).");
936   TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char,
937       "qlinear with mkldnn tensor: data type of weight should be int8 (char).");
938   TORCH_CHECK(
939       weight_scales.scalar_type() == c10::ScalarType::Float, "weight scales should be dtype c10::ScalarType::Float.");
940   TORCH_CHECK(
941       binary_alpha == 1.0f, "onednn qlinear: alpha != 1 for binary post op is not yet supported.");
942   bool fp32_output = output_dtype.has_value() && (output_dtype.value() == c10::kFloat);
943   bool bf16_output = output_dtype.has_value() && (output_dtype.value() == c10::kBFloat16);
944   if (fp32_output || bf16_output) {
945     TORCH_CHECK(
946         output_scale == 1.0f && output_zero_point == 0, "onednn qlinear: expect scale=1 and zero point=0 for fp32/bf16 output");
947   }
948   if (binary_post_op != "none") {
949     /* Supported cases for binary post op:
950       +-------------------+--------------+---------------+
951       | Extra input dtype | Output dtype | Post op       |
952       +-------------------+--------------+---------------+
953       | Fp32/bf16         | fp32/bf16    | sum           |
954       +-------------------+--------------+---------------+
955       | Fp32/bf16         | int8         | add           |
956       +-------------------+--------------+---------------+
957       | int8              | fp32/bf16    | not supported |
958       +-------------------+--------------+---------------+
959       | int8              | int8         | sum           |
960       +-------------------+--------------+---------------+
961     */
962     TORCH_CHECK(other.has_value(), "onednn qlinear: the extra input is missing for post op ", binary_post_op);
963     if (fp32_output || bf16_output) {
964       TORCH_CHECK(
965           other_scale == 1.0f && other_zero_point == 0,
966           "onednn qlinear: expect extra input scale = 1.0 and zero point = 0 when output dtype is ", output_dtype.value(),
967           ", but got ", other_scale, " and ", other_zero_point, ", respectively"
968       );
969     }
970     if (binary_post_op == "sum") {
971       auto expected_dtype = output_dtype.has_value() ? output_dtype.value() : c10::kByte;
972       TORCH_CHECK(
973           other.value().scalar_type() == expected_dtype,
974           "onednn qlinear: the dtype of extra input for binary post op should be ", expected_dtype,
975           " (same as output dtype), but got ", other.value().scalar_type()
976       );
977     }
978   }
979 
980   // If the input has more than two dimensions, we will reshape it to a 2-dimensional form
981   // for calculation and subsequently reshape the output back.
982   auto input_contig =
983       dim == 2 ? input.contiguous() : input.reshape({-1, input.size(dim - 1)}).contiguous();
984 
985   auto src = at::native::itensor_from_tensor(input_contig);
986   auto packed_weight = at::native::itensor_from_mkldnn(onednn_weight);
987   int64_t K = input.size(dim - 1), M = input.numel() / K, N = packed_weight.get_dim(1);
988 
989   auto output_size = input.sizes().vec();
990   output_size[dim - 1] = N;
991 
992   std::optional<ideep::tensor> onednn_bias{std::nullopt};
993   bool with_bias = bias.has_value();
994   at::Tensor bias_val_float;
995   if (with_bias) {
996     bias_val_float = bias.value().to(at::kFloat);
997     if (bias_val_float.dim() == 1) {
998       auto b_reshape = bias_val_float.reshape({1, bias_val_float.size(0)});
999       onednn_bias = at::native::itensor_view_from_dense(b_reshape);
1000     } else {
1001       onednn_bias = at::native::itensor_view_from_dense(bias_val_float);
1002     }
1003   }
1004   std::vector<int64_t> src_dims = {M, K};
1005   std::vector<int64_t> dst_dims = {M, N};
1006   at::Tensor output = binary_post_op == "sum" ?
1007       other.value() :
1008       at::empty(
1009         dst_dims,
1010         device(c10::kCPU)
1011             .dtype(fp32_output ? c10::kFloat : (bf16_output ? c10::kBFloat16 : c10::kByte))
1012       );
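  // "sum" fuses the accumulation into the destination: `other` itself is used
  // as the output buffer and oneDNN adds the matmul result to it in place,
  // whereas "add" keeps `other` as a separate second input (bound as
  // DNNL_ARG_SRC_1 in the exec args below) and writes into a freshly
  // allocated output.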
1013   if (output.numel() == 0) {
1014     return output;
1015   }
1016   tensor dst = at::native::itensor_view_from_dense(output);
1017   static tensor empty_tensor;
1018   static tensor::desc empty_tensor_desc;
1019   tensor src1 = binary_post_op == "add" ?
1020       at::native::itensor_view_from_dense(other.value().reshape({-1, other.value().size(dim - 1)})) :
1021       empty_tensor;
1022 
1023   // Create onednn primitive
1024   auto src_desc = tensor::desc(src_dims, ideep::data_type::u8, ideep::format_tag::any);
1025   auto weights_desc = packed_weight.get_desc();
1026   auto dst_dtype = dst.get_data_type();
1027   auto dst_desc = tensor::desc(dst_dims, dst_dtype, ideep::format_tag::any);
1028   auto bias_desc = with_bias ?
1029       tensor::desc(onednn_bias.value().get_dims(), ideep::data_type::f32, ideep::format_tag::any) :
1030       empty_tensor_desc;
1031   // Get op attr for primitive
1032   // Note: output_scale & output_zero_point are for re-quantization of the final output.
1033   // And other_scale & other_zero_point are for dequantization of other.
1034   auto other_desc = binary_post_op == "add" ? src1.get_desc() : empty_tensor_desc;
1035   auto op_attr = onednn_utils::create_attr_by_post_op(
1036     binary_post_op,
1037     binary_alpha,
1038     other_scale,
1039     other_zero_point,
1040     other_desc,
1041     unary_post_op,
1042     unary_post_op_args,
1043     unary_post_op_algorithm
1044   );
1045   if (input_scale != 1.0f) {
1046     op_attr.set_scales_mask(DNNL_ARG_SRC, 0);
1047   }
1048   if (input_zero_point != 0) {
1049     op_attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
1050   }
1051   op_attr.set_scales_mask(DNNL_ARG_WEIGHTS, ideep::utils::op_scale_mask(weight_scales.numel()));
1052   if (output_scale != 1.0f) {
1053     op_attr.set_scales_mask(DNNL_ARG_DST, 0);
1054   }
1055   if (output_zero_point != 0) {
1056     op_attr.set_zero_points_mask(DNNL_ARG_DST, 0);
1057   }
1058   op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
1059   auto engine = ideep::engine::cpu_engine();
1060   auto primitive_desc = with_bias ?
1061       dnnl::matmul::primitive_desc(engine, src_desc, weights_desc, bias_desc, dst_desc, op_attr) :
1062       dnnl::matmul::primitive_desc(engine, src_desc, weights_desc, dst_desc, op_attr);
1063   auto primitive = dnnl::matmul(primitive_desc);
1064 
1065   // Reorder weight if needed
1066   auto expected_weight = packed_weight.reorder_if_differ_in(primitive_desc.weights_desc());
1067 
1068   // Prepare args and execute primitive
1069   tensor scratchpad(primitive_desc.scratchpad_desc());
1070   ideep::exec_args args;
1071   args.insert({DNNL_ARG_SRC, src});
1072   args.insert({DNNL_ARG_WEIGHTS, expected_weight});
1073   args.insert({DNNL_ARG_DST, dst});
1074   args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
1075   if (with_bias) {
1076     args.insert({DNNL_ARG_BIAS, onednn_bias.value()});
1077   }
1078   tensor src_scales_t = tensor(ideep::scale_t(1, input_scale));
1079   tensor wei_scales_t = at::native::itensor_from_tensor(weight_scales);
1080   tensor dst_scales_t = tensor(ideep::scale_t(1, output_scale));
1081   tensor src_zp_t = tensor(ideep::zero_point_t(1, input_zero_point));
1082   tensor dst_zp_t = tensor(ideep::zero_point_t(1, output_zero_point));
1083   if (input_scale != 1.0f) {
1084     args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_t});
1085   }
1086   if (output_scale != 1.0f) {
1087     args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_t});
1088   }
1089   args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_t});
1090   if (input_zero_point != 0) {
1091     args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_t});
1092   }
1093   if (output_zero_point != 0) {
1094     args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_t});
1095   }
1096   if (binary_post_op == "add") {
1097     args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, src1});
1098   }
1099   primitive.execute(ideep::stream::default_stream(), args);
1100   return dim == 2 ? output : output.reshape(output_size);
1101 }
1102 #endif // #if AT_MKLDNN_ENABLED()
1103 
1104 namespace at {
1105 namespace native {
1106 namespace {
1107 
1108 template <bool ReluFused>
1109 class QLinearInt8 final {
1110  public:
1111   static at::Tensor run(
1112       at::Tensor input,
1113       const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
1114       double output_scale,
1115       int64_t output_zero_point) {
1116     if (ReluFused) {
1117       return packed_weight->apply_relu(
1118           std::move(input), output_scale, output_zero_point);
1119     } else {
1120       return packed_weight->apply(
1121           std::move(input), output_scale, output_zero_point);
1122     }
1123   }
1124 };
1125 
1126 class QLinearLeakyReluInt8 final {
1127  public:
1128   static at::Tensor run(
1129       at::Tensor input,
1130       const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
1131       double output_scale,
1132       int64_t output_zero_point,
1133       double negative_slope) {
1134 #if AT_MKLDNN_ENABLED() || !defined(STRIP_ERROR_MESSAGES)
1135     auto& ctx = at::globalContext();
1136 #endif
1137 #if AT_MKLDNN_ENABLED()
1138     if (ctx.qEngine() == at::QEngine::ONEDNN) {
1139       return dynamic_cast<PackedLinearWeightsOnednn*>(packed_weight.get())->apply_leaky_relu(
1140           std::move(input), output_scale, output_zero_point, negative_slope);
1141     }
1142 #endif
1143     TORCH_CHECK(
1144         false,
1145         "Didn't find engine for operation quantized::linear_leaky_relu ",
1146         toString(ctx.qEngine()));
1147   }
1148 };
1149 
1150 
1151 class QLinearTanhInt8 final {
1152  public:
1153   static at::Tensor run(
1154       at::Tensor input,
1155       const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
1156       double output_scale,
1157       int64_t output_zero_point) {
1158 #if AT_MKLDNN_ENABLED() || !defined(STRIP_ERROR_MESSAGES)
1159     auto& ctx = at::globalContext();
1160 #endif
1161 #if AT_MKLDNN_ENABLED()
1162     if (ctx.qEngine() == at::QEngine::ONEDNN) {
1163       return dynamic_cast<PackedLinearWeightsOnednn*>(packed_weight.get())->apply_tanh(
1164           std::move(input), output_scale, output_zero_point);
1165     }
1166 #endif
1167     TORCH_CHECK(
1168         false,
1169         "Didn't find engine for operation quantized::linear_tanh ",
1170         toString(ctx.qEngine()));
1171   }
1172 };
1173 
1174 template <bool ReluFused>
1175 class QLinearInt8FusedQDQ final {
1176  public:
1177   static at::Tensor run(
1178       at::Tensor input,
1179       double input_scale,
1180       int64_t input_zero_point,
1181       const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight) {
1182     if (ReluFused) {
1183       return packed_weight->apply_with_input_q_dq_qweight_dq_relu_output_fp32(
1184           std::move(input), input_scale, input_zero_point);
1185     } else {
1186       return packed_weight->apply_with_input_q_dq_qweight_dq_output_fp32(
1187           std::move(input), input_scale, input_zero_point);
1188     }
1189   }
1190 };
1191 
1192 class QLinearOnednn final {
1193  public:
1194   static Tensor run_pointwise(
1195       Tensor act, // int8 CPU tensor, not QTensor
1196       double act_scale,
1197       int64_t act_zero_point,
1198       Tensor onednn_weight, // int8 tensor from MkldnnCPU
1199       Tensor weight_scales,
1200       Tensor weight_zero_points,
1201       std::optional<Tensor> bias,
1202       double output_scale,
1203       int64_t output_zero_point,
1204       std::optional<c10::ScalarType> output_dtype,
1205       c10::string_view post_op_name,
1206       torch::List<std::optional<at::Scalar>> post_op_args,
1207       c10::string_view post_op_algorithm) {
1208 #if AT_MKLDNN_ENABLED()
1209     static std::optional<at::Tensor> other = std::nullopt;
1210     static const c10::string_view binary_post_op = "none";
1211     return linear_int8_with_onednn_weight(
1212         act, act_scale, act_zero_point,
1213         onednn_weight, weight_scales, weight_zero_points,
1214         bias, output_scale, output_zero_point, output_dtype,
1215         other, /*other scale*/1.0, /*other zp*/0,
1216         binary_post_op, /*binary alpha*/1.0,
1217         post_op_name, post_op_args, post_op_algorithm
1218     );
1219 #endif
1220     TORCH_CHECK(false, "Unimplemented (int8 linear with packed weight and bias)");
1221   }
1222 
1223   static Tensor run_pointwise_tensor(
1224       Tensor act, // int8 CPU tensor, not QTensor
1225       Tensor act_scale,
1226       Tensor act_zero_point,
1227       Tensor onednn_weight, // int8 tensor from MkldnnCPU
1228       Tensor weight_scales,
1229       Tensor weight_zero_points,
1230       std::optional<Tensor> bias,
1231       double output_scale,
1232       int64_t output_zero_point,
1233       std::optional<c10::ScalarType> output_dtype,
1234       c10::string_view post_op_name,
1235       torch::List<std::optional<at::Scalar>> post_op_args,
1236       c10::string_view post_op_algorithm) {
1237 #if AT_MKLDNN_ENABLED()
1238     TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1,
1239         "onednn int8 linear: act scale/zp size should be 1");
1240     static std::optional<at::Tensor> other = std::nullopt;
1241     static const c10::string_view binary_post_op = "none";
1242     return linear_int8_with_onednn_weight(
1243         act, act_scale.item().toDouble(), act_zero_point.item().toLong(),
1244         onednn_weight, weight_scales, weight_zero_points,
1245         bias, output_scale, output_zero_point, output_dtype,
1246         other, /*other scale*/1.0, /*other zp*/0,
1247         binary_post_op, /*binary alpha*/1.0,
1248         post_op_name, post_op_args, post_op_algorithm
1249     );
1250 #endif
1251     TORCH_CHECK(false, "Unimplemented (int8 linear with packed weight and bias)");
1252   }
1253 
1254   static Tensor run_pointwise_binary(
1255       Tensor act, // int8 CPU tensor, not QTensor
1256       double act_scale,
1257       int64_t act_zero_point,
1258       Tensor onednn_weight, // int8 tensor from MkldnnCPU
1259       Tensor weight_scales,
1260       Tensor weight_zero_points,
1261       std::optional<at::Tensor> other, // extra input for binary post-op
1262       std::optional<Tensor> bias,
1263       double output_scale,
1264       int64_t output_zero_point,
1265       std::optional<c10::ScalarType> output_dtype,
1266       double other_scale,
1267       int64_t other_zero_point,
1268       c10::string_view binary_post_op, // e.g. "none", "sum", "add"
1269       double binary_alpha,
1270       c10::string_view unary_post_op, // e.g. "none", "relu"
1271       torch::List<std::optional<at::Scalar>> unary_post_op_args,
1272       c10::string_view unary_post_op_algorithm) {
1273 #if AT_MKLDNN_ENABLED()
1274     return linear_int8_with_onednn_weight(
1275         act, act_scale, act_zero_point,
1276         onednn_weight, weight_scales, weight_zero_points,
1277         bias, output_scale, output_zero_point, output_dtype,
1278         other, other_scale, other_zero_point,
1279         binary_post_op, binary_alpha,
1280         unary_post_op, unary_post_op_args, unary_post_op_algorithm
1281     );
1282 #endif
1283     TORCH_CHECK(false, "Unimplemented (int8 linear with packed weight and bias)");
1284   }
1285 
1286   static Tensor run_pointwise_binary_tensor(
1287       Tensor act, // int8 CPU tensor, not QTensor
1288       Tensor act_scale,
1289       Tensor act_zero_point,
1290       Tensor onednn_weight, // int8 tensor from MkldnnCPU
1291       Tensor weight_scales,
1292       Tensor weight_zero_points,
1293       std::optional<at::Tensor> other, // extra input for binary post-op
1294       std::optional<Tensor> bias,
1295       double output_scale,
1296       int64_t output_zero_point,
1297       std::optional<c10::ScalarType> output_dtype,
1298       double other_scale,
1299       int64_t other_zero_point,
1300       c10::string_view binary_post_op, // e.g. "none", "sum", "add"
1301       double binary_alpha,
1302       c10::string_view unary_post_op, // e.g. "none", "relu"
1303       torch::List<std::optional<at::Scalar>> unary_post_op_args,
1304       c10::string_view unary_post_op_algorithm) {
1305 #if AT_MKLDNN_ENABLED()
1306     TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1,
1307         "onednn int8 linear: act scale/zp size should be 1");
1308     return linear_int8_with_onednn_weight(
1309         act, act_scale.item().toDouble(), act_zero_point.item().toLong(),
1310         onednn_weight, weight_scales, weight_zero_points,
1311         bias, output_scale, output_zero_point, output_dtype,
1312         other, other_scale, other_zero_point,
1313         binary_post_op, binary_alpha,
1314         unary_post_op, unary_post_op_args, unary_post_op_algorithm
1315     );
1316 #endif
1317     TORCH_CHECK(false, "Unimplemented (int8 linear with packed weight and bias)");
1318   }
1319 };
1320 
1321 TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
1322   register_linear_params();
1323   m.impl(TORCH_SELECTIVE_NAME("quantized::linear"), TORCH_FN(QLinearInt8<false>::run));
1324   m.impl(TORCH_SELECTIVE_NAME("quantized::linear_relu"), TORCH_FN(QLinearInt8<true>::run));
1325   m.impl(TORCH_SELECTIVE_NAME("quantized::linear_leaky_relu"), TORCH_FN(QLinearLeakyReluInt8::run));
1326   m.impl(TORCH_SELECTIVE_NAME("quantized::linear_tanh"), TORCH_FN(QLinearTanhInt8::run));
1327 }
1328 
1329 TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) {
1330   register_linear_params();
1331   m.impl(TORCH_SELECTIVE_NAME("_quantized::linear"), TORCH_FN(QLinearInt8<false>::run));
1332 }
1333 
1334 TORCH_LIBRARY_IMPL(quantized, CPU, m) {
1335   m.impl(TORCH_SELECTIVE_NAME("quantized::linear_with_input_q_dq_qweight_dq_output_fp32"), TORCH_FN(QLinearInt8FusedQDQ<false>::run));
1336   m.impl(TORCH_SELECTIVE_NAME("quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32"), TORCH_FN(QLinearInt8FusedQDQ<true>::run));
1337 }
1338 
1339 TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) {
1340   m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise"),
1341       TORCH_FN(QLinearOnednn::run_pointwise));
1342   m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.tensor"),
1343       TORCH_FN(QLinearOnednn::run_pointwise_tensor));
1344   m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary"),
1345       TORCH_FN(QLinearOnednn::run_pointwise_binary));
1346   m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary_tensor"),
1347       TORCH_FN(QLinearOnednn::run_pointwise_binary_tensor));
1348 }
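// Note the split above: the quantized::linear* and _quantized::linear ops take
// a prepacked LinearPackedParamsBase and dispatch to the
// PackedLinearWeight*::apply* implementations defined earlier in this file,
// while the onednn::qlinear_pointwise* ops take a plain int8 weight tensor on
// the MkldnnCPU dispatch key plus explicit scales and zero points, and go
// straight to linear_int8_with_onednn_weight.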
1349 
1350 } // namespace
1351 } // namespace native
1352 } // namespace at
1353