1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Context.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/TensorOperators.h>
6 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
7 #include <ATen/native/quantized/PackedParams.h>
8 #include <ATen/native/quantized/cpu/QnnpackUtils.h>
9 #include <ATen/native/quantized/cpu/XnnpackUtils.h>
10 #include <ATen/native/quantized/cpu/OnednnUtils.h>
11 #include <ATen/native/quantized/cpu/QuantUtils.h>
12 #include <ATen/native/mkldnn/MKLDNNCommon.h>
13 #include <caffe2/utils/threadpool/pthreadpool-cpp.h>
14 #include <torch/library.h>
15
16 #ifndef AT_PER_OPERATOR_HEADERS
17 #include <ATen/Functions.h>
18 #include <ATen/NativeFunctions.h>
19 #else
20 #include <ATen/ops/_empty_affine_quantized.h> // for _empty_affine_q...
21 #include <ATen/ops/_empty_affine_quantized_native.h> // for empty_affine_qu...
22 #include <ATen/ops/empty.h> // for empty
23 #include <ATen/ops/quantize_per_channel_native.h> // for quantize_per_ch...
24 #include <ATen/ops/quantize_per_tensor_native.h> // for quantize_per_te...
25 #include <ATen/ops/zeros.h>
26 #endif
27
28 #include <c10/util/irange.h>
29
30 #include <algorithm>
31 #include <string>
32
33 int register_linear_params();
34
35 #ifdef USE_FBGEMM
36 template <bool ReluFused>
apply_impl(const at::Tensor & input,double output_scale,int64_t output_zero_point,at::Tensor & output)37 at::Tensor& PackedLinearWeight::apply_impl(
38 const at::Tensor& input,
39 double output_scale,
40 int64_t output_zero_point,
41 at::Tensor& output) {
42 // uint8 * int8 -> uint8 (no quantization/dequantization)
43
44 // We make a strong guarantee that models using these operators will have
45 // the same numerics across different machines. Therefore, we do not provide
46 // a fallback path and rather fail loudly if we cannot run FBGEMM.
47 TORCH_CHECK(
48 fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
49 TORCH_CHECK(input.scalar_type() == c10::kQUInt8,
50 "Expected input data type ",
51 toString(c10::kQUInt8),
52 " but got ",
53 toString(input.scalar_type()));
54
55 // TODO: contiguous is called for further jit optimizations.
56 auto input_contig = input.expect_contiguous();
57 const auto* input_ptr =
58 reinterpret_cast<uint8_t*>(input_contig->data_ptr<c10::quint8>());
59
60 TORCH_CHECK(
61 input.dim() >= 2,
62 "The dimension of input tensor should be larger than or equal to 2");
63 // C(output) = A(input) x B(weight), where C, A, B are M x N, M x K, K x N
64 // matrices, respectively.
65 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
66 int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
67
68 auto packB = w.get();
69
70 int64_t N = static_cast<int64_t>(packB->numCols());
71 int64_t K = input.sizes()[input.dim() - 1];
72 TORCH_CHECK(
73 K == static_cast<int64_t>(packB->numRows()),
74 "The number of rows in the packB should be equal to K: " +
75 std::to_string(K));
76
77 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
78 float input_scale_float = input.q_scale();
79 int32_t input_zero_point_int32 = input.q_zero_point();
80
81 std::vector<float> output_multiplier_float(1, 0.0);
82 std::vector<float> act_times_w_scale(1, 0.0);
83 TORCH_CHECK(
84 w_scale.size() == w_zp.size(),
85 "Weight scales and zero points vectors should have the same size.");
86 if (q_scheme == c10::kPerTensorAffine) {
87 // Process the per tensor quantization.
88 act_times_w_scale[0] = (input_scale_float * w_scale[0]);
89 output_multiplier_float[0] =
90 act_times_w_scale[0] / static_cast<float>(output_scale);
91 } else if (q_scheme == c10::kPerChannelAffine) {
92 // Process the per channel quantization.
93 output_multiplier_float.resize(N, 0.0);
94 act_times_w_scale.resize(N, 1.0f);
95 for (const auto i : c10::irange(N)) {
96 act_times_w_scale[i] = (input_scale_float * w_scale[i]);
97 output_multiplier_float[i] =
98 act_times_w_scale[i] / static_cast<float>(output_scale);
99 }
100 }
101 int32_t output_zero_point_int32 = static_cast<int32_t>(output_zero_point);
102
103 const float* bias_ptr = nullptr;
104 c10::MaybeOwned<at::Tensor> bias_contig;
105 if (this->bias_.has_value()) {
106 auto& bias = this->bias_.value();
107 bias_contig = bias.expect_contiguous();
108 TORCH_CHECK(bias_contig->dim() == 1, "bias should be a vector (1D Tensor)");
109 TORCH_CHECK(
110 bias_contig->sizes()[0] == N, "bias should have N elements: " + std::to_string(N));
111 bias_ptr = reinterpret_cast<float*>(bias_contig->data_ptr<float>());
112 }
113
114 // The resulting matrix here is 2-D, let's view it with the original
115 // left hand dimensions of the input. Here are two examples:
116 // 1. If the input tensor is {M, K}, the output tensor is {M, N}.
117 // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
118 at::DimVector out_sizes(input.sizes());
119 out_sizes.back() = N;
120 // Resize output Tensor
121 output.resize_(out_sizes);
122
123 // Allocate a buffer for fbgemmPacked to use
124 auto buffer = at::empty(out_sizes, output.options().dtype(at::kInt));
125
126 auto output_data = reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>());
127
128 int num_tasks = at::get_num_threads();
129 at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
130 for (const auto task_id : c10::irange(begin, end)) {
131 // This operation does the following:
132 // 1) Creates a "row buffer" vector with offset values that must be
133 // added to the integer matrix multiplication operation to ensure
134 // correctness. This "row buffer" is also called the row offset, and
135 // it is needed when we use affine quantization for weights.
136 // 2) Packs the resulting quantized matrix into vector-register and
137 // cache friendly tiles.
138 //
139 // Note this is not executed eagerly, but rather within the
140 // fbgemmPacked call below.
141 fbgemm::PackAWithRowOffset<uint8_t> packA(
142 /*trans=*/fbgemm::matrix_op_t::NoTranspose,
143 /*nRow=*/M,
144 /*nCol=*/K,
145 /*smat=*/input_ptr,
146 /*ld=*/K,
147 /*pmat=*/nullptr); // Currently, packA manages ownership of `pmat`.
148 // TODO: Consider a way to pre-allocate and reuse
149 // pmat buffer.
150
151 // ReQuantizeOutput requires pointers to the zero point values,
152 // since in the case of rowwise quantization these will be arrays rather
153 // than scalars. But in this case, we're doing whole-tensor quantization
154 // so we just pass a pointer to the scale values (and internally
155 // ReQuantizeOutput won't index past 0.
156
157 // This is the end of the pipeline, pass the resulting matrix through.
158 fbgemm::DoNothing<> doNothingObj{};
159
160 if (q_scheme == c10::kPerTensorAffine) {
161 // Process the per tensor quantization.
162 //
163 // After the uint8 * int8 matrix multiplication is performed, this
164 // operation does:
165 // 1) Add in row and column offsets to the rows and columns,
166 // respectively.
167 // 2) Add in the bias term.
168 fbgemm::ReQuantizeOutput<
169 ReluFused,
170 fbgemm::QuantizationGranularity::TENSOR,
171 float>
172 outputProcObj(
173 doNothingObj,
174 output_multiplier_float.data(),
175 output_zero_point_int32,
176 input_zero_point_int32,
177 w_zp.data(),
178 packA.getRowOffsetBuffer(),
179 col_offsets.data(),
180 bias_ptr,
181 N, /* nCol */
182 1 /* groups */,
183 act_times_w_scale.data());
184
185 // Do the GEMM
186 fbgemm::fbgemmPacked(
187 /*packA=*/packA,
188 /*packB=*/*packB,
189 /*C=*/output_data,
190 /*C_buffer=*/buffer.data_ptr<int32_t>(),
191 /*ldc=*/N,
192 /*outProcess=*/outputProcObj,
193 /*thread_id=*/task_id,
194 /*num_threads=*/num_tasks);
195 } else if (q_scheme == c10::kPerChannelAffine) {
196 // Process the per channel quantization.
197 //
198 // After the uint8 * int8 matrix multiplication is performed, this
199 // operation does:
200 // 1) Add in row and column offsets to the rows and columns,
201 // respectively.
202 // 2) Add in the bias term.
203 fbgemm::ReQuantizeOutput<
204 ReluFused,
205 fbgemm::QuantizationGranularity::OUT_CHANNEL,
206 float>
207 outputProcObj(
208 doNothingObj,
209 output_multiplier_float.data(),
210 output_zero_point_int32,
211 input_zero_point_int32,
212 w_zp.data(),
213 packA.getRowOffsetBuffer(),
214 col_offsets.data(),
215 bias_ptr,
216 // NOLINTNEXTLINE(bugprone-argument-comment)
217 N, /*nCol=*/
218 1, /* groups*/
219 act_times_w_scale.data());
220
221 // Do the GEMM
222 fbgemm::fbgemmPacked(
223 /*packA=*/packA,
224 /*packB=*/*packB,
225 /*C=*/output_data,
226 /*C_buffer=*/buffer.data_ptr<int32_t>(),
227 /*ldc=*/N,
228 /*outProcess=*/outputProcObj,
229 /*thread_id=*/task_id,
230 /*num_threads=*/num_tasks);
231 }
232 }
233 });
234
235 return output;
236 }
237
apply(at::Tensor input,double output_scale,int64_t output_zero_point)238 at::Tensor PackedLinearWeight::apply(
239 at::Tensor input,
240 double output_scale,
241 int64_t output_zero_point) {
242 // Allocate output Tensor
243 auto output = at::_empty_affine_quantized(
244 {0},
245 at::device(c10::kCPU).dtype(c10::kQUInt8),
246 output_scale,
247 output_zero_point);
248 apply_impl<false>(input, output_scale, output_zero_point, output);
249 return output;
250 }
251
apply_relu(at::Tensor input,double output_scale,int64_t output_zero_point)252 at::Tensor PackedLinearWeight::apply_relu(
253 at::Tensor input,
254 double output_scale,
255 int64_t output_zero_point) {
256 auto output = at::_empty_affine_quantized(
257 {0},
258 at::device(c10::kCPU).dtype(c10::kQUInt8),
259 output_scale,
260 output_zero_point);
261 apply_impl<true>(input, output_scale, output_zero_point, output);
262 return output;
263 }
264
apply_out(const at::Tensor & input,double output_scale,int64_t output_zero_point,at::Tensor & output)265 at::Tensor& PackedLinearWeight::apply_out(
266 const at::Tensor& input,
267 double output_scale,
268 int64_t output_zero_point,
269 at::Tensor& output) {
270 TORCH_CHECK(
271 (output.device() == c10::kCPU) && (output.dtype() == c10::kQUInt8) &&
272 (output.q_scale() == output_scale) &&
273 (output.q_zero_point() == output_zero_point));
274 return apply_impl<false>(input, output_scale, output_zero_point, output);
275 }
276
apply_relu_out(const at::Tensor & input,double output_scale,int64_t output_zero_point,at::Tensor & output)277 at::Tensor& PackedLinearWeight::apply_relu_out(
278 const at::Tensor& input,
279 double output_scale,
280 int64_t output_zero_point,
281 at::Tensor& output) {
282 TORCH_CHECK(
283 (output.device() == c10::kCPU) && (output.dtype() == c10::kQUInt8) &&
284 (output.q_scale() == output_scale) &&
285 (output.q_zero_point() == output_zero_point));
286 return apply_impl<true>(input, output_scale, output_zero_point, output);
287 }
288
apply_with_input_q_dq_qweight_dq_output_fp32(at::Tensor input,double input_scale,int64_t input_zero_point)289 at::Tensor PackedLinearWeight::apply_with_input_q_dq_qweight_dq_output_fp32(
290 at::Tensor input,
291 double input_scale,
292 int64_t input_zero_point) {
293 TORCH_CHECK(!input.is_quantized(), "Input tensor for apply_with_input_q_dq_qweight_dq_output_fp32 is quantized; "
294 "Expected input tensor in PackedLinearWeight::apply_with_input_q_dq_qweight_dq_output_fp32 to be full precision.");
295
296 return apply_with_input_q_dq_qweight_dq_output_fp32_impl<false>(input, input_scale, input_zero_point);
297 }
298
apply_with_input_q_dq_qweight_dq_relu_output_fp32(at::Tensor input,double input_scale,int64_t input_zero_point)299 at::Tensor PackedLinearWeight::apply_with_input_q_dq_qweight_dq_relu_output_fp32(
300 at::Tensor input,
301 double input_scale,
302 int64_t input_zero_point) {
303 TORCH_CHECK(!input.is_quantized(), "Input tensor for apply_with_input_q_dq_qweight_dq_output_fp32 is quantized; "
304 "Expected input tensor in PackedLinearWeight::apply_with_input_q_dq_qweight_dq_output_fp32 to be full precision.");
305
306 return apply_with_input_q_dq_qweight_dq_output_fp32_impl<true>(input, input_scale, input_zero_point);
307 }
308
309
310 template <bool ReluFused>
apply_with_input_q_dq_qweight_dq_output_fp32_impl(const at::Tensor & input,double input_scale,int64_t input_zero_point)311 at::Tensor PackedLinearWeight::apply_with_input_q_dq_qweight_dq_output_fp32_impl(
312 const at::Tensor& input,
313 double input_scale,
314 int64_t input_zero_point) {
315 TORCH_CHECK(
316 fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
317
318 auto input_contig = input.expect_contiguous();
319 const auto* input_ptr = input_contig->const_data_ptr<float>();
320
321 TORCH_CHECK(
322 input.dim() >= 2,
323 "The dimension of input tensor should be larger than or equal to 2");
324 int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
325
326 auto packB = w.get();
327
328 int64_t N = static_cast<int64_t>(packB->numCols());
329 int64_t K = input.sizes()[input.dim() - 1];
330 TORCH_CHECK(
331 K == static_cast<int64_t>(packB->numRows()),
332 "The number of rows in the packB should be equal to K: " +
333 std::to_string(K));
334
335 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
336 float input_scale_float = input_scale;
337 int32_t input_zero_point_int32 = input_zero_point;
338
339 TORCH_CHECK(
340 w_scale.size() == w_zp.size(),
341 "Weight scales and zero points vectors should have the same size.");
342
343 const float* bias_ptr = nullptr;
344 c10::MaybeOwned<at::Tensor> bias_contig;
345 if (this->bias_.has_value()) {
346 auto& bias = this->bias_.value();
347 bias_contig = bias.expect_contiguous();
348 TORCH_CHECK(bias_contig->dim() == 1, "bias should be a vector (1D Tensor)");
349 TORCH_CHECK(
350 bias_contig->sizes()[0] == N, "bias should have N elements: " + std::to_string(N));
351 bias_ptr = bias_contig->data_ptr<float>();
352 }
353
354 std::vector<int64_t> out_sizes = input.sizes().vec();
355 out_sizes.back() = N;
356 // Allocate output Tensor and a buffer for fbgemmPacked to use
357 auto output = at::empty(out_sizes, input.options().dtype(at::kFloat));
358 auto buffer = at::empty_like(
359 output,
360 output.options().dtype(at::kInt),
361 LEGACY_CONTIGUOUS_MEMORY_FORMAT);
362
363 auto output_data = output.data_ptr<float>();
364
365 int num_tasks = at::get_num_threads();
366 at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
367 fbgemm::PackAWithQuantRowOffset<uint8_t> packA(
368 /*trans=*/fbgemm::matrix_op_t::NoTranspose,
369 /*nRow=*/M,
370 /*nCol=*/K,
371 /*smat=*/input_ptr,
372 /*ld=*/K,
373 /*pmat=*/nullptr,
374 /*scale=*/input_scale_float,
375 /*zero_pt=*/input_zero_point_int32);
376
377 fbgemm::DoNothing<float, float> doNothingObj{};
378 for (const auto task_id : c10::irange(begin, end)) {
379 if (q_scheme == c10::kPerTensorAffine) {
380 // Process the per tensor quantization.
381 //
382 // After the uint8 * int8 matrix multiplication is performed, this
383 // operation does:
384 // 1) Add in row and column offsets to the rows and columns,
385 // respectively.
386 // 2) Add in the bias term.
387 fbgemm::ReQuantizeForFloat<ReluFused>
388 outputProcObj(
389 doNothingObj,
390 input_scale_float,
391 w_scale.data(),
392 input_zero_point_int32,
393 w_zp.data(),
394 packA.getRowOffsetBuffer(),
395 col_offsets.data(),
396 bias_ptr,
397 N /* nCol */);
398
399 // Do the GEMM
400 fbgemm::fbgemmPacked(
401 /*packA=*/packA,
402 /*packB=*/*packB,
403 /*C=*/output_data,
404 /*C_buffer=*/buffer.data_ptr<int32_t>(),
405 /*ldc=*/N,
406 /*outProcess=*/outputProcObj,
407 /*thread_id=*/task_id,
408 /*num_threads=*/num_tasks);
409 } else if (q_scheme == c10::kPerChannelAffine) {
410 // Process the per channel quantization.
411 //
412 // After the uint8 * int8 matrix multiplication is performed, this
413 // operation does:
414 // 1) Add in row and column offsets to the rows and columns,
415 // respectively.
416 // 2) Add in the bias term.
417 fbgemm::ReQuantizeForFloat<
418 ReluFused,
419 fbgemm::QuantizationGranularity::OUT_CHANNEL>
420 outputProcObj(
421 doNothingObj,
422 input_scale_float,
423 w_scale.data(),
424 input_zero_point_int32,
425 w_zp.data(),
426 packA.getRowOffsetBuffer(),
427 col_offsets.data(),
428 bias_ptr,
429 N /* nCol */);
430
431 // Do the GEMM
432 fbgemm::fbgemmPacked(
433 /*packA=*/packA,
434 /*packB=*/*packB,
435 /*C=*/output_data,
436 /*C_buffer=*/buffer.data_ptr<int32_t>(),
437 /*ldc=*/N,
438 /*outProcess=*/outputProcObj,
439 /*thread_id=*/task_id,
440 /*num_threads=*/num_tasks);
441 }
442 }
443 });
444 return output;
445 }
446
447 #endif // USE_FBGEMM
448
449 #ifdef USE_PYTORCH_QNNPACK
450
451 #ifdef USE_XNNPACK
452 // TODO: add per_channel support in the future when xnnp supports it
453 template <typename scalar_t, bool kReluFused>
apply_impl_xnnp(const at::Tensor & input,double output_scale,int64_t output_zero_point)454 at::Tensor PackedLinearWeightsQnnp::apply_impl_xnnp(
455 const at::Tensor& input,
456 double output_scale,
457 int64_t output_zero_point) {
458 using underlying_t = typename scalar_t::underlying;
459
460 std::lock_guard<std::mutex> lock(qnnp_mutex_);
461
462 const std::string func_name = kReluFused ? "quantized::linear_relu (xnnpack)"
463 : "quantized::linear (xnnpack)";
464 TORCH_CHECK(
465 input.dim() >= 2, func_name, ": Input tensor rank should be >= 2.");
466 TORCH_CHECK(
467 !per_channel(),
468 func_name,
469 ": xnnpack does not currently have per_channel support.");
470
471 const auto input_contig = input.contiguous();
472 const auto input_scale = input_contig.q_scale();
473
474 const size_t rows_w = bias_.size(0);
475 const size_t cols_w = input_contig.size(input_contig.dim() - 1);
476
477 auto status = xnn_status_invalid_state;
478
479 // Create an operator iff not already created
480 if (!xnnp_linear_op ||
481 (!this->input_scale.has_value() ||
482 this->input_scale.value() != input_scale)) {
483 // Update the input scale so we may cache the op
484 this->input_scale = input_scale;
485
486 xnn_operator_t xnnp_op = nullptr;
487
488 const float* weight_scales_data = w_scales.const_data_ptr<float>();
489
490 // prepare weights
491 underlying_t w_zp = static_cast<underlying_t>(
492 orig_weight.q_zero_point() +
493 (std::is_same<underlying_t, uint8_t>::value ? 128 : 0));
494
495 at::Tensor xnnp_weight = at::_empty_affine_quantized(
496 orig_weight.sizes(),
497 c10::CppTypeToScalarType<scalar_t>::value,
498 weight_scales_data[0],
499 w_zp);
500
501 // copy from the original weight and take care of dtype change if necessary
502 at::native::xnnp_utils::q8_copy_int8_weight_and_add_offset<scalar_t>(
503 orig_weight, xnnp_weight);
504
505 // Original bias was float, so we requantize it here.
506 at::Tensor qbias = quant_utils::QuantizeBias(false, bias_, orig_weight, input_scale);
507
508 // output limits
509 auto output_min = kReluFused
510 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
511 ? activationLimits<underlying_t>(output_scale, output_zero_point, Activation::RELU).first
512 : std::numeric_limits<underlying_t>::min();
513 auto output_max = kReluFused
514 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
515 ? activationLimits<underlying_t>(output_scale, output_zero_point, Activation::RELU).second
516 : std::numeric_limits<underlying_t>::max();
517
518 // Create an operator
519 status = at::native::xnnp_utils::xnnp_create_fully_connected_nc(
520 cols_w, /* input_channels */
521 rows_w, /* output_channels */
522 cols_w, /* input_stride */
523 rows_w, /* output_stride */
524 input_contig.q_zero_point(),
525 input_contig.q_scale(),
526 w_zp,
527 weight_scales_data[0],
528 reinterpret_cast<const underlying_t*>(
529 xnnp_weight.template data_ptr<scalar_t>()),
530 reinterpret_cast<int32_t*>(qbias.data_ptr<c10::qint32>()),
531 output_zero_point,
532 output_scale,
533 output_min,
534 output_max,
535 0, /* flags */
536 &xnnp_op);
537 xnnp_linear_op = xnnpack_operator(xnnp_op);
538
539 TORCH_CHECK(
540 status == xnn_status_success,
541 func_name,
542 ": xnn create operator failed(",
543 status,
544 ")");
545 }
546
547 /*
548 * Allocate output Tensor and a buffer for XNNPACK to use
549 * The resulting matrix here is 2-D, let's view it with the original
550 * left hand dimensions of the input. Here are two examples:
551 * 1. If the input tensor is {M, K}, the output tensor is {M, N}.
552 * 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
553 */
554 std::vector<int64_t> out_sizes = input.sizes().vec();
555 out_sizes.back() = static_cast<int64_t>(rows_w);
556 at::Tensor output = at::native::empty_affine_quantized(
557 out_sizes,
558 c10::CppTypeToScalarType<scalar_t>::value,
559 std::nullopt /* layout */,
560 c10::kCPU,
561 std::nullopt /* pin_memory */,
562 output_scale,
563 output_zero_point,
564 input.suggest_memory_format());
565
566 // calculate batch_size
567 size_t rows_input = 1;
568 for (const auto i : c10::irange(input_contig.dim() - 1)) {
569 rows_input *= input_contig.size(i);
570 }
571
572 // Reshape the operator
573 status = at::native::xnnp_utils::xnnp_reshape_fully_connected_nc(
574 xnnp_linear_op.get(),
575 rows_input, /* batch_size */
576 caffe2::pthreadpool_());
577
578 // Setup the operator
579 status = at::native::xnnp_utils::xnnp_setup_fully_connected_nc(
580 xnnp_linear_op.get(),
581 reinterpret_cast<const underlying_t*>(
582 input_contig.template data_ptr<scalar_t>()),
583 reinterpret_cast<underlying_t*>(output.template data_ptr<scalar_t>())
584 );
585
586 TORCH_CHECK(
587 status == xnn_status_success,
588 func_name,
589 ": xnn setup operator failed(",
590 status,
591 ")");
592
593 // Run the operator
594 status = xnn_run_operator(
595 xnnp_linear_op.get(), // Linear op
596 caffe2::pthreadpool_() // threadpool
597 );
598 TORCH_CHECK(
599 status == xnn_status_success,
600 func_name,
601 ": xnn run operator failed(",
602 status,
603 ")");
604
605 return output;
606 }
607 #endif // USE_XNNPACK
608
609 template <bool ReluFused>
apply_impl(at::Tensor input,double output_scale,int64_t output_zero_point)610 at::Tensor PackedLinearWeightsQnnp::apply_impl(
611 at::Tensor input,
612 double output_scale,
613 int64_t output_zero_point) {
614 TORCH_CHECK(
615 input.dim() >= 2,
616 "quantized::linear(): Input tensor rank should be >= 2");
617 TORCH_CHECK(input.scalar_type() == c10::kQUInt8,
618 "quantized::linear (qnnpack): Expected input data type ",
619 toString(c10::kQUInt8),
620 " but got ",
621 toString(input.scalar_type()));
622
623 auto input_contig = input.contiguous();
624
625 // Weight packing is not thread safe
626 std::lock_guard<std::mutex> lock(qnnp_mutex_);
627 auto packB = w.get();
628 size_t rows_w = bias_.size(0);
629 size_t cols_w = input_contig.size(input_contig.dim() - 1);
630 auto input_scale = input_contig.q_scale();
631
632 if (!this->input_scale.has_value() ||
633 this->input_scale.value() != input_scale) {
634 // Get the original weight and adjust it to uint8 from int8
635 auto weight_contig = orig_weight;
636 auto bias_fp32 = bias_;
637 int8_t* w_data = (int8_t*)weight_contig.data_ptr<c10::qint8>();
638
639 float* weight_scales_data = w_scales.data_ptr<float>();
640 // We calculate requant scale here as the vector holding the requant scale
641 // is owned by this module. The pointer is then passed to qnnpack backend.
642 generate_requantization_scales(
643 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
644 w_scales, input_scale, output_scale, requantization_scales);
645
646 at::Tensor qnnp_weight = at::_empty_affine_quantized(
647 weight_contig.sizes(),
648 at::device(c10::kCPU).dtype(c10::kQUInt8),
649 weight_scales_data[0],
650 w_zero_points[0]);
651 auto* qnnp_w_data = qnnp_weight.data_ptr<c10::quint8>();
652 auto wt_numel = weight_contig.numel();
653 for (const auto i : c10::irange(wt_numel)) {
654 qnnp_w_data[i] = static_cast<c10::quint8>(w_data[i] + 128);
655 }
656 // Original bias was float, so we requantize it here.
657 const bool is_per_channel = orig_weight.qscheme() == at::kPerChannelAffine;
658 at::Tensor qbias = quant_utils::QuantizeBias(is_per_channel, bias_fp32, weight_contig, input_scale);
659
660 // Update the input scale to not pack again.
661 this->input_scale = input_scale;
662 w.reset();
663 w = std::make_unique<qnnpack::PackBMatrix>(
664 cols_w /* input_channels */,
665 rows_w /* output_channels */,
666 w_zero_points.data(),
667 requantization_scales.data(),
668 reinterpret_cast<uint8_t*>(qnnp_w_data),
669 reinterpret_cast<int32_t*>(qbias.data_ptr<c10::qint32>()));
670 packB = w.get();
671 if (at::globalContext().releaseWeightsWhenPrepacking()) {
672 // On mobile, we release the original weight by resetting the intrusive_ptr.
673 // Calling unpack after this will throw an assertion.
674 orig_weight.reset();
675 }
676 }
677
678 size_t rows_input = 1;
679 size_t cols_input = input_contig.size(input_contig.dim() - 1);
680 for (const auto i : c10::irange(input_contig.dim() -1)) {
681 rows_input *= input_contig.size(i);
682 }
683
684 TORCH_CHECK(
685 cols_input == cols_w,
686 "quantized::linear(): input size does not match weight dimension 1 size: \
687 got ",
688 cols_input,
689 " but expected ",
690 cols_w);
691
692 // Allocate output Tensor and a buffer for QNNPACK to use
693 // The resulting matrix here is 2-D, let's view it with the original
694 // left hand dimensions of the input. Here are two examples:
695 // 1. If the input tensor is {M, K}, the output tensor is {M, N}.
696 // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
697 std::vector<int64_t> out_sizes = input.sizes().vec();
698 out_sizes.back() = static_cast<long>(rows_w);
699 at::Tensor output = at::_empty_affine_quantized(
700 out_sizes,
701 input.options(),
702 output_scale,
703 output_zero_point);
704
705 auto output_min = ReluFused
706 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
707 ? activationLimits<uint8_t>(output_scale, output_zero_point, Activation::RELU)
708 .first
709 : std::numeric_limits<uint8_t>::min();
710 auto output_max = ReluFused
711 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
712 ? activationLimits<uint8_t>(output_scale, output_zero_point, Activation::RELU)
713 .second
714 : std::numeric_limits<uint8_t>::max();
715 TORCH_INTERNAL_ASSERT(packB != nullptr, "Packed Weights are NULL");
716 const pytorch_qnnp_status runStatus = qnnpack::qnnpackLinear(
717 rows_input /* batch_size */,
718 cols_input /* input_channels */,
719 rows_w /* output_channels */,
720 input_contig.q_zero_point(),
721 w_zero_points.data(),
722 requantization_scales.data(),
723 output_zero_point,
724 output_min,
725 output_max,
726 (uint8_t*)input_contig.data_ptr<c10::quint8>(),
727 cols_input /* input_stride */,
728 packB->getPackedWeights(),
729 (uint8_t*)output.data_ptr<c10::quint8>(),
730 rows_w /* output_stride */,
731 // TODO (Ashkan): Disabling temporarily.
732 // Throws a floating point exception with OSS pthreadpool.
733 caffe2::pthreadpool_() /* threadpool */);
734
735 TORCH_INTERNAL_ASSERT(
736 runStatus == pytorch_qnnp_status_success,
737 "failed to run QNNPACK Linear operator");
738
739 return output;
740 }
741
742 #ifdef USE_XNNPACK
can_use_xnnp(c10::ScalarType dtype,bool per_channel)743 static bool can_use_xnnp(c10::ScalarType dtype, bool per_channel) {
744 if(!at::native::xnnpack::available()) {
745 return false;
746 }
747
748 bool supported_dtypes = dtype == c10::kQInt8;
749 bool invalid_config = per_channel; /* xnnp does not currently support
750 per-channel fully connected op */
751 if (supported_dtypes && invalid_config) {
752 /* don't want this to fall through to QNNPACK */
753 TORCH_CHECK(
754 false,
755 "quantized::linear (xnnpack): Unsupported config for dtype KQInt8");
756 }
757 return supported_dtypes && !invalid_config;
758 }
759 #endif // USE_XNNPACK
760
apply(at::Tensor input,double output_scale,int64_t output_zero_point)761 at::Tensor PackedLinearWeightsQnnp::apply(
762 at::Tensor input,
763 double output_scale,
764 int64_t output_zero_point) {
765 #ifdef USE_XNNPACK
766 if (can_use_xnnp(input.scalar_type(), per_channel())) {
767 return apply_impl_xnnp<c10::qint8, false>(
768 input, output_scale, output_zero_point);
769 } /* fall through for unsupported types, configs, or shapes */
770 #endif // USE_XNNPACK
771 return apply_impl<false>(std::move(input), output_scale, output_zero_point);
772 }
773
apply_relu(at::Tensor input,double output_scale,int64_t output_zero_point)774 at::Tensor PackedLinearWeightsQnnp::apply_relu(
775 at::Tensor input,
776 double output_scale,
777 int64_t output_zero_point) {
778 #ifdef USE_XNNPACK
779 if (can_use_xnnp(input.scalar_type(), per_channel())) {
780 return apply_impl_xnnp<c10::qint8, true>(
781 input, output_scale, output_zero_point);
782 } /* fall through for unsupported types, configs, or shapes */
783 #endif // USE_XNNPACK
784 return apply_impl<true>(std::move(input), output_scale, output_zero_point);
785 }
786
787 #endif // USE_PYTORCH_QNNPACK
788
789 #if AT_MKLDNN_ENABLED()
790 template <PostOps post_op>
apply_impl(at::Tensor input,double output_scale,int64_t output_zero_point,torch::List<at::Scalar> post_op_args)791 at::Tensor PackedLinearWeightsOnednn::apply_impl(
792 at::Tensor input,
793 double output_scale,
794 int64_t output_zero_point,
795 torch::List<at::Scalar> post_op_args) {
796 const int64_t dim = input.dim();
797 TORCH_CHECK(
798 dim != 0,
799 "qlinear (ONEDNN): input dim should be at least 1, but got 0");
800 TORCH_CHECK(input.scalar_type() == c10::ScalarType::QUInt8,
801 "qlinear (ONEDNN): data type of input should be QUint8.");
802
803 auto input_contig = input.expect_contiguous();
804 auto& w = *(weight_.get());
805 auto K = input.size(dim - 1), M = input.numel() / K, N = w.get_dim(1);
806 auto input_dims = {M, K};
807 auto input_data_type = dnnl::memory::data_type::u8;
808 auto input_desc = ideep::tensor::desc(input_dims, input_data_type);
809 ideep::attr_t op_attr = ideep::attr_t();
810 if (post_op == Relu) {
811 op_attr = ideep::attr_t::fuse_relu();
812 } else if (post_op == LeakyRelu) {
813 op_attr = ideep::attr_t::fuse_relu(/*scale=*/1.0f, /*alpha=*/post_op_args.get(0).to<double>());
814 } else if (post_op == Tanh) {
815 op_attr = ideep::attr_t::fuse_tanh();
816 }
817 ideep::tensor x(input_desc, input_contig->data_ptr<c10::quint8>());
818 auto dst_dims = {M, N};
819 double input_scale = input.q_scale();
820 int64_t input_zero_point = input.q_zero_point();
821 const ideep::scale_t& src_scales = ideep::scale_t(1, 1.0/input_scale);
822 const ideep::scale_t& weights_scales = w.get_scale();
823 // Scales of ONEDNN and PyTorch are reciprocal
824 const ideep::scale_t& dst_scales = ideep::scale_t(1, 1.0/output_scale);
825 const ideep::zero_point_t& src_zero_point = ideep::zero_point_t(1, input_zero_point);
826 const ideep::zero_point_t& dst_zero_point = ideep::zero_point_t(1, output_zero_point);
827 // Compute: Use ideep::matmul_forward to support asymmetric quantization
828 // Allocate output Tensor
829 at::Tensor output = at::_empty_affine_quantized(
830 dst_dims,
831 at::device(c10::kCPU).dtype(c10::kQUInt8),
832 output_scale,
833 output_zero_point);
834 if (output.numel() == 0) {
835 return output;
836 }
837 ideep::tensor y({dst_dims, ideep::tensor::data_type::u8,
838 {output.strides().cbegin(), output.strides().cend()}},
839 output.data_ptr());
840 bool with_bias = bias_.has_value();
841 if (with_bias) {
842 // Bias might be modified outside (e.g. by quantization bias correction).
843 // If so, update the prepacked bias as well.
844 if (bias_.value().get_data_handle() != orig_bias_.value().data_ptr()) {
845 bias_.value().init(bias_.value().get_desc(), orig_bias_.value().data_ptr());
846 }
847 }
848 const auto& b = with_bias ? bias_.value() : ideep::tensor();
849 // Primitive cache is initialized when called for the first time
850 // and won't be updated afterwards.
851 int num_threads = at::get_num_threads();
852 PrimitiveCacheKey cache_key = std::make_tuple(
853 input_scale, input_zero_point, input_dims, output_scale, output_zero_point, num_threads, /*accum scale*/1.0, /*accum zero point*/0);
854 c10::call_once(*cache_initialized_flag, [&](){
855 LinearParams params;
856 ideep::matmul_forward::prepare</*is_dynamic=*/false>(
857 params, x, w, b, y,
858 src_scales, weights_scales, dst_scales,
859 src_zero_point, dst_zero_point, 1.0f, 1.0f, op_attr);
860 get_cache() = LinearPrimitiveCache(cache_key, params);
861 w = w.reorder_if_differ_in(params.pd.weights_desc());
862 });
863 if (get_cache().hit(cache_key)) {
864 LinearParams& params = get_cache().get_param();
865 ideep::matmul_forward::compute<false, false>(params, x, w, b, y);
866 } else {
867 ideep::matmul_forward::compute(x, w, b, y, src_scales, weights_scales,
868 dst_scales, src_zero_point, dst_zero_point,
869 1.0f, 1.0f, op_attr);
870 }
871 auto out_sizes = input.sizes().vec();
872 out_sizes.back() = N;
873 if (output.sizes().vec() == out_sizes)
874 return output;
875 return output.reshape(out_sizes);
876 }
877
apply(at::Tensor input,double output_scale,int64_t output_zero_point)878 at::Tensor PackedLinearWeightsOnednn::apply(
879 at::Tensor input,
880 double output_scale,
881 int64_t output_zero_point) {
882 return apply_impl<NoPostOp>(
883 std::move(input), output_scale, output_zero_point);
884 }
885
apply_relu(at::Tensor input,double output_scale,int64_t output_zero_point)886 at::Tensor PackedLinearWeightsOnednn::apply_relu(
887 at::Tensor input,
888 double output_scale,
889 int64_t output_zero_point) {
890 return apply_impl<Relu>(
891 std::move(input), output_scale, output_zero_point);
892 }
893
apply_leaky_relu(at::Tensor input,double output_scale,int64_t output_zero_point,double negative_slope)894 at::Tensor PackedLinearWeightsOnednn:: apply_leaky_relu(
895 at::Tensor input,
896 double output_scale,
897 int64_t output_zero_point,
898 double negative_slope) {
899 torch::List<at::Scalar> post_op_args =
900 {at::Scalar(negative_slope)};
901 return apply_impl<LeakyRelu>(
902 std::move(input), output_scale, output_zero_point, post_op_args);
903 }
904
apply_tanh(at::Tensor input,double output_scale,int64_t output_zero_point)905 at::Tensor PackedLinearWeightsOnednn:: apply_tanh(
906 at::Tensor input,
907 double output_scale,
908 int64_t output_zero_point) {
909 return apply_impl<Tanh>(
910 std::move(input), output_scale, output_zero_point);
911 }
912
linear_int8_with_onednn_weight(at::Tensor input,double input_scale,int64_t input_zero_point,at::Tensor onednn_weight,at::Tensor weight_scales,at::Tensor weight_zero_points,std::optional<at::Tensor> bias,double output_scale,int64_t output_zero_point,std::optional<c10::ScalarType> output_dtype,std::optional<at::Tensor> other,double other_scale,int64_t other_zero_point,const c10::string_view & binary_post_op,double binary_alpha,const c10::string_view & unary_post_op,torch::List<std::optional<at::Scalar>> & unary_post_op_args,c10::string_view & unary_post_op_algorithm)913 static at::Tensor linear_int8_with_onednn_weight(
914 at::Tensor input, // int8 CPU Tensor, not QTensor
915 double input_scale,
916 int64_t input_zero_point,
917 at::Tensor onednn_weight, // int8 tensor from MkldnnCPU
918 at::Tensor weight_scales,
919 at::Tensor weight_zero_points,
920 std::optional<at::Tensor> bias, // plain tensor
921 double output_scale,
922 int64_t output_zero_point,
923 std::optional<c10::ScalarType> output_dtype,
924 std::optional<at::Tensor> other, // extra input for binary post-op
925 double other_scale,
926 int64_t other_zero_point,
927 const c10::string_view& binary_post_op, // e.g. "none", "sum", "add"
928 double binary_alpha,
929 const c10::string_view& unary_post_op, // e.g. "none", "relu"
930 torch::List<std::optional<at::Scalar>>& unary_post_op_args,
931 c10::string_view& unary_post_op_algorithm) {
932 using ideep::tensor;
933 const int64_t dim = input.dim();
934 TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte,
935 "qlinear with mkldnn tensor: data type of input should be uint8 (unsigned char).");
936 TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char,
937 "qlinear with mkldnn tensor: data type of weight should be int8 (char).");
938 TORCH_CHECK(
939 weight_scales.scalar_type() == c10::ScalarType::Float, "weight scales should be dtype c10::ScalarType::Float.");
940 TORCH_CHECK(
941 binary_alpha == 1.0f, "onednn qlinear: alpha != 1 for binary post op is not yet supported.");
942 bool fp32_output = output_dtype.has_value() && (output_dtype.value() == c10::kFloat);
943 bool bf16_output = output_dtype.has_value() && (output_dtype.value() == c10::kBFloat16);
944 if (fp32_output || bf16_output) {
945 TORCH_CHECK(
946 output_scale == 1.0f && output_zero_point == 0, "onednn qlinear: expect scale=1 and zero point=0 for fp32 output");
947 }
948 if (binary_post_op != "none") {
949 /* Supported cases for binary post op:
950 +-------------------+--------------+---------------+
951 | Extra input dtype | Output dtype | Post op |
952 +-------------------+--------------+---------------+
953 | Fp32/bf16 | fp32/bf16 | sum |
954 +-------------------+--------------+---------------+
955 | Fp32/bf16 | int8 | add |
956 +-------------------+--------------+---------------+
957 | int8 | fp32/bf16 | not supported |
958 +-------------------+--------------+---------------+
959 | int8 | int8 | sum |
960 +-------------------+--------------+---------------+
961 */
962 TORCH_CHECK(other.has_value(), "onednn qlinear: the extra input is missing for post op ", binary_post_op);
963 if (fp32_output || bf16_output) {
964 TORCH_CHECK(
965 other_scale == 1.0f && other_zero_point == 0,
966 "onednn qlinear: expect extra input scale = 1.0 and zero point = 0 when output dtype is ", output_dtype.value(),
967 ", but got ", other_scale, " and ", other_zero_point, ", respectively"
968 );
969 }
970 if (binary_post_op == "sum") {
971 auto expected_dtype = output_dtype.has_value() ? output_dtype.value() : c10::kByte;
972 TORCH_CHECK(
973 other.value().scalar_type() == expected_dtype,
974 "onednn qlinear: the dtype of extra input for binary post op should be ", expected_dtype,
975 " (same as output dtype), but got ", other.value().scalar_type()
976 );
977 }
978 }
979
980 // If the input has more than two dimensions, we will reshape it to a 2-dimensional form
981 // for calculation and subsequently reshape the output back.
982 auto input_contig =
983 dim == 2 ? input.contiguous() : input.reshape({-1, input.size(dim - 1)}).contiguous();
984
985 auto src = at::native::itensor_from_tensor(input_contig);
986 auto packed_weight = at::native::itensor_from_mkldnn(onednn_weight);
987 int64_t K = input.size(dim - 1), M = input.numel() / K, N = packed_weight.get_dim(1);
988
989 auto output_size = input.sizes().vec();
990 output_size[dim - 1] = N;
991
992 std::optional<ideep::tensor> onednn_bias{std::nullopt};
993 bool with_bias = bias.has_value();
994 at::Tensor bias_val_float;
995 if (with_bias) {
996 bias_val_float = bias.value().to(at::kFloat);
997 if (bias_val_float.dim() == 1) {
998 auto b_reshape = bias_val_float.reshape({1, bias_val_float.size(0)});
999 onednn_bias = at::native::itensor_view_from_dense(b_reshape);
1000 } else {
1001 onednn_bias = at::native::itensor_view_from_dense(bias_val_float);
1002 }
1003 }
1004 std::vector<int64_t> src_dims = {M, K};
1005 std::vector<int64_t> dst_dims = {M, N};
1006 at::Tensor output = binary_post_op == "sum" ?
1007 other.value() :
1008 at::empty(
1009 dst_dims,
1010 device(c10::kCPU)
1011 .dtype(fp32_output ? c10::kFloat : (bf16_output ? c10::kBFloat16 : c10::kByte))
1012 );
1013 if (output.numel() == 0) {
1014 return output;
1015 }
1016 tensor dst = at::native::itensor_view_from_dense(output);
1017 static tensor empty_tensor;
1018 static tensor::desc empty_tensor_desc;
1019 tensor src1 = binary_post_op == "add" ?
1020 at::native::itensor_view_from_dense(other.value().reshape({-1, other.value().size(dim - 1)})) :
1021 empty_tensor;
1022
1023 // Create onednn primitive
1024 auto src_desc = tensor::desc(src_dims, ideep::data_type::u8, ideep::format_tag::any);
1025 auto weights_desc = packed_weight.get_desc();
1026 auto dst_dtype = dst.get_data_type();
1027 auto dst_desc = tensor::desc(dst_dims, dst_dtype, ideep::format_tag::any);
1028 auto bias_desc = with_bias ?
1029 tensor::desc(onednn_bias.value().get_dims(), ideep::data_type::f32, ideep::format_tag::any) :
1030 empty_tensor_desc;
1031 // Get op attr for primitive
1032 // Note: output_scale & output_zero_point are for re-quantization of the final output.
1033 // And other_scale & other_zero_point are for dequantization of other.
1034 auto other_desc = binary_post_op == "add" ? src1.get_desc() : empty_tensor_desc;
1035 auto op_attr = onednn_utils::create_attr_by_post_op(
1036 binary_post_op,
1037 binary_alpha,
1038 other_scale,
1039 other_zero_point,
1040 other_desc,
1041 unary_post_op,
1042 unary_post_op_args,
1043 unary_post_op_algorithm
1044 );
1045 if (input_scale != 1.0f) {
1046 op_attr.set_scales_mask(DNNL_ARG_SRC, 0);
1047 }
1048 if (input_zero_point != 0) {
1049 op_attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
1050 }
1051 op_attr.set_scales_mask(DNNL_ARG_WEIGHTS, ideep::utils::op_scale_mask(weight_scales.numel()));
1052 if (output_scale != 1.0f) {
1053 op_attr.set_scales_mask(DNNL_ARG_DST, 0);
1054 }
1055 if (output_zero_point != 0) {
1056 op_attr.set_zero_points_mask(DNNL_ARG_DST, 0);
1057 }
1058 op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
1059 auto engine = ideep::engine::cpu_engine();
1060 auto primitive_desc = with_bias ?
1061 dnnl::matmul::primitive_desc(engine, src_desc, weights_desc, bias_desc, dst_desc, op_attr) :
1062 dnnl::matmul::primitive_desc(engine, src_desc, weights_desc, dst_desc, op_attr);
1063 auto primitive = dnnl::matmul(primitive_desc);
1064
1065 // Reorder weight if needed
1066 auto expected_weight = packed_weight.reorder_if_differ_in(primitive_desc.weights_desc());
1067
1068 // Prepare args and execute primitive
1069 tensor scratchpad(primitive_desc.scratchpad_desc());
1070 ideep::exec_args args;
1071 args.insert({DNNL_ARG_SRC, src});
1072 args.insert({DNNL_ARG_WEIGHTS, expected_weight});
1073 args.insert({DNNL_ARG_DST, dst});
1074 args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
1075 if (with_bias) {
1076 args.insert({DNNL_ARG_BIAS, onednn_bias.value()});
1077 }
1078 tensor src_scales_t = tensor(ideep::scale_t(1, input_scale));
1079 tensor wei_scales_t = at::native::itensor_from_tensor(weight_scales);
1080 tensor dst_scales_t = tensor(ideep::scale_t(1, output_scale));
1081 tensor src_zp_t = tensor(ideep::zero_point_t(1, input_zero_point));
1082 tensor dst_zp_t = tensor(ideep::zero_point_t(1, output_zero_point));
1083 if (input_scale != 1.0f) {
1084 args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_t});
1085 }
1086 if (output_scale != 1.0f) {
1087 args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_t});
1088 }
1089 args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_t});
1090 if (input_zero_point != 0) {
1091 args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_t});
1092 }
1093 if (output_zero_point != 0) {
1094 args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_t});
1095 }
1096 if (binary_post_op == "add") {
1097 args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, src1});
1098 }
1099 primitive.execute(ideep::stream::default_stream(), args);
1100 return dim == 2 ? output : output.reshape(output_size);
1101 }
1102 #endif // #if AT_MKLDNN_ENABLED()
1103
1104 namespace at {
1105 namespace native {
1106 namespace {
1107
1108 template <bool ReluFused>
1109 class QLinearInt8 final {
1110 public:
run(at::Tensor input,const c10::intrusive_ptr<LinearPackedParamsBase> & packed_weight,double output_scale,int64_t output_zero_point)1111 static at::Tensor run(
1112 at::Tensor input,
1113 const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
1114 double output_scale,
1115 int64_t output_zero_point) {
1116 if (ReluFused) {
1117 return packed_weight->apply_relu(
1118 std::move(input), output_scale, output_zero_point);
1119 } else {
1120 return packed_weight->apply(
1121 std::move(input), output_scale, output_zero_point);
1122 }
1123 }
1124 };
1125
1126 class QLinearLeakyReluInt8 final {
1127 public:
run(at::Tensor input,const c10::intrusive_ptr<LinearPackedParamsBase> & packed_weight,double output_scale,int64_t output_zero_point,double negative_slope)1128 static at::Tensor run(
1129 at::Tensor input,
1130 const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
1131 double output_scale,
1132 int64_t output_zero_point,
1133 double negative_slope) {
1134 #if AT_MKLDNN_ENABLED() || !defined(STRIP_ERROR_MESSAGES)
1135 auto& ctx = at::globalContext();
1136 #endif
1137 #if AT_MKLDNN_ENABLED()
1138 if (ctx.qEngine() == at::QEngine::ONEDNN) {
1139 return dynamic_cast<PackedLinearWeightsOnednn*>(packed_weight.get())->apply_leaky_relu(
1140 std::move(input), output_scale, output_zero_point, negative_slope);
1141 }
1142 #endif
1143 TORCH_CHECK(
1144 false,
1145 "Didn't find engine for operation quantized::linear_leaky_relu ",
1146 toString(ctx.qEngine()));
1147 }
1148 };
1149
1150
1151 class QLinearTanhInt8 final {
1152 public:
run(at::Tensor input,const c10::intrusive_ptr<LinearPackedParamsBase> & packed_weight,double output_scale,int64_t output_zero_point)1153 static at::Tensor run(
1154 at::Tensor input,
1155 const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
1156 double output_scale,
1157 int64_t output_zero_point) {
1158 #if AT_MKLDNN_ENABLED() || !defined(STRIP_ERROR_MESSAGES)
1159 auto& ctx = at::globalContext();
1160 #endif
1161 #if AT_MKLDNN_ENABLED()
1162 if (ctx.qEngine() == at::QEngine::ONEDNN) {
1163 return dynamic_cast<PackedLinearWeightsOnednn*>(packed_weight.get())->apply_tanh(
1164 std::move(input), output_scale, output_zero_point);
1165 }
1166 #endif
1167 TORCH_CHECK(
1168 false,
1169 "Didn't find engine for operation quantized::linear_tanh ",
1170 toString(ctx.qEngine()));
1171 }
1172 };
1173
1174 template <bool ReluFused>
1175 class QLinearInt8FusedQDQ final {
1176 public:
run(at::Tensor input,double input_scale,int64_t input_zero_point,const c10::intrusive_ptr<LinearPackedParamsBase> & packed_weight)1177 static at::Tensor run(
1178 at::Tensor input,
1179 double input_scale,
1180 int64_t input_zero_point,
1181 const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight) {
1182 if (ReluFused) {
1183 return packed_weight->apply_with_input_q_dq_qweight_dq_relu_output_fp32(
1184 std::move(input), input_scale, input_zero_point);
1185 } else {
1186 return packed_weight->apply_with_input_q_dq_qweight_dq_output_fp32(
1187 std::move(input), input_scale, input_zero_point);
1188 }
1189 }
1190 };
1191
1192 class QLinearOnednn final {
1193 public:
run_pointwise(Tensor act,double act_scale,int64_t act_zero_point,Tensor onednn_weight,Tensor weight_scales,Tensor weight_zero_points,std::optional<Tensor> bias,double output_scale,int64_t output_zero_point,std::optional<c10::ScalarType> output_dtype,c10::string_view post_op_name,torch::List<std::optional<at::Scalar>> post_op_args,c10::string_view post_op_algorithm)1194 static Tensor run_pointwise(
1195 Tensor act, // int8 CPU tensor, not QTensor
1196 double act_scale,
1197 int64_t act_zero_point,
1198 Tensor onednn_weight, // int8 tensor from MkldnnCPU
1199 Tensor weight_scales,
1200 Tensor weight_zero_points,
1201 std::optional<Tensor> bias,
1202 double output_scale,
1203 int64_t output_zero_point,
1204 std::optional<c10::ScalarType> output_dtype,
1205 c10::string_view post_op_name,
1206 torch::List<std::optional<at::Scalar>> post_op_args,
1207 c10::string_view post_op_algorithm) {
1208 #if AT_MKLDNN_ENABLED()
1209 static std::optional<at::Tensor> other = std::nullopt;
1210 static const c10::string_view binary_post_op = "none";
1211 return linear_int8_with_onednn_weight(
1212 act, act_scale, act_zero_point,
1213 onednn_weight, weight_scales, weight_zero_points,
1214 bias, output_scale, output_zero_point, output_dtype,
1215 other, /*other scale*/1.0, /*other zp*/0,
1216 binary_post_op, /*binary alpha*/1.0,
1217 post_op_name, post_op_args, post_op_algorithm
1218 );
1219 #endif
1220 TORCH_CHECK(false, "Unimplemented (int8 linear with packed weight and bias)");
1221 }
1222
run_pointwise_tensor(Tensor act,Tensor act_scale,Tensor act_zero_point,Tensor onednn_weight,Tensor weight_scales,Tensor weight_zero_points,std::optional<Tensor> bias,double output_scale,int64_t output_zero_point,std::optional<c10::ScalarType> output_dtype,c10::string_view post_op_name,torch::List<std::optional<at::Scalar>> post_op_args,c10::string_view post_op_algorithm)1223 static Tensor run_pointwise_tensor(
1224 Tensor act, // int8 CPU tensor, not QTensor
1225 Tensor act_scale,
1226 Tensor act_zero_point,
1227 Tensor onednn_weight, // int8 tensor from MkldnnCPU
1228 Tensor weight_scales,
1229 Tensor weight_zero_points,
1230 std::optional<Tensor> bias,
1231 double output_scale,
1232 int64_t output_zero_point,
1233 std::optional<c10::ScalarType> output_dtype,
1234 c10::string_view post_op_name,
1235 torch::List<std::optional<at::Scalar>> post_op_args,
1236 c10::string_view post_op_algorithm) {
1237 #if AT_MKLDNN_ENABLED()
1238 TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1,
1239 "onednn int8 linear: act scale/zp size should be 1");
1240 static std::optional<at::Tensor> other = std::nullopt;
1241 static const c10::string_view binary_post_op = "none";
1242 return linear_int8_with_onednn_weight(
1243 act, act_scale.item().toDouble(), act_zero_point.item().toLong(),
1244 onednn_weight, weight_scales, weight_zero_points,
1245 bias, output_scale, output_zero_point, output_dtype,
1246 other, /*other scale*/1.0, /*other zp*/0,
1247 binary_post_op, /*binary alpha*/1.0,
1248 post_op_name, post_op_args, post_op_algorithm
1249 );
1250 #endif
1251 TORCH_CHECK(false, "Unimplemented (int8 linear with packed weight and bias)");
1252 }
1253
run_pointwise_binary(Tensor act,double act_scale,int64_t act_zero_point,Tensor onednn_weight,Tensor weight_scales,Tensor weight_zero_points,std::optional<at::Tensor> other,std::optional<Tensor> bias,double output_scale,int64_t output_zero_point,std::optional<c10::ScalarType> output_dtype,double other_scale,int64_t other_zero_point,c10::string_view binary_post_op,double binary_alpha,c10::string_view unary_post_op,torch::List<std::optional<at::Scalar>> unary_post_op_args,c10::string_view unary_post_op_algorithm)1254 static Tensor run_pointwise_binary(
1255 Tensor act, // int8 CPU tensor, not QTensor
1256 double act_scale,
1257 int64_t act_zero_point,
1258 Tensor onednn_weight, // int8 tensor from MkldnnCPU
1259 Tensor weight_scales,
1260 Tensor weight_zero_points,
1261 std::optional<at::Tensor> other, // extra input for binary post-op
1262 std::optional<Tensor> bias,
1263 double output_scale,
1264 int64_t output_zero_point,
1265 std::optional<c10::ScalarType> output_dtype,
1266 double other_scale,
1267 int64_t other_zero_point,
1268 c10::string_view binary_post_op, // e.g. "none", "sum", "add"
1269 double binary_alpha,
1270 c10::string_view unary_post_op, // e.g. "none", "relu"
1271 torch::List<std::optional<at::Scalar>> unary_post_op_args,
1272 c10::string_view unary_post_op_algorithm) {
1273 #if AT_MKLDNN_ENABLED()
1274 return linear_int8_with_onednn_weight(
1275 act, act_scale, act_zero_point,
1276 onednn_weight, weight_scales, weight_zero_points,
1277 bias, output_scale, output_zero_point, output_dtype,
1278 other, other_scale, other_zero_point,
1279 binary_post_op, binary_alpha,
1280 unary_post_op, unary_post_op_args, unary_post_op_algorithm
1281 );
1282 #endif
1283 TORCH_CHECK(false, "Unimplemented (int8 linear with packed weight and bias)");
1284 }
1285
run_pointwise_binary_tensor(Tensor act,Tensor act_scale,Tensor act_zero_point,Tensor onednn_weight,Tensor weight_scales,Tensor weight_zero_points,std::optional<at::Tensor> other,std::optional<Tensor> bias,double output_scale,int64_t output_zero_point,std::optional<c10::ScalarType> output_dtype,double other_scale,int64_t other_zero_point,c10::string_view binary_post_op,double binary_alpha,c10::string_view unary_post_op,torch::List<std::optional<at::Scalar>> unary_post_op_args,c10::string_view unary_post_op_algorithm)1286 static Tensor run_pointwise_binary_tensor(
1287 Tensor act, // int8 CPU tensor, not QTensor
1288 Tensor act_scale,
1289 Tensor act_zero_point,
1290 Tensor onednn_weight, // int8 tensor from MkldnnCPU
1291 Tensor weight_scales,
1292 Tensor weight_zero_points,
1293 std::optional<at::Tensor> other, // extra input for binary post-op
1294 std::optional<Tensor> bias,
1295 double output_scale,
1296 int64_t output_zero_point,
1297 std::optional<c10::ScalarType> output_dtype,
1298 double other_scale,
1299 int64_t other_zero_point,
1300 c10::string_view binary_post_op, // e.g. "none", "sum", "add"
1301 double binary_alpha,
1302 c10::string_view unary_post_op, // e.g. "none", "relu"
1303 torch::List<std::optional<at::Scalar>> unary_post_op_args,
1304 c10::string_view unary_post_op_algorithm) {
1305 #if AT_MKLDNN_ENABLED()
1306 TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1,
1307 "onednn int8 linear: act scale/zp size should be 1");
1308 return linear_int8_with_onednn_weight(
1309 act, act_scale.item().toDouble(), act_zero_point.item().toLong(),
1310 onednn_weight, weight_scales, weight_zero_points,
1311 bias, output_scale, output_zero_point, output_dtype,
1312 other, other_scale, other_zero_point,
1313 binary_post_op, binary_alpha,
1314 unary_post_op, unary_post_op_args, unary_post_op_algorithm
1315 );
1316 #endif
1317 TORCH_CHECK(false, "Unimplemented (int8 linear with packed weight and bias)");
1318 }
1319 };
1320
// QTensor linear kernels registered for the QuantizedCPU dispatch key.
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  // Called before the kernels are bound; presumably registers the
  // LinearPackedParamsBase class these kernels take (see
  // register_linear_params).
  register_linear_params();
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::linear"),
      TORCH_FN(QLinearInt8<false>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::linear_relu"),
      TORCH_FN(QLinearInt8<true>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::linear_leaky_relu"),
      TORCH_FN(QLinearLeakyReluInt8::run));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::linear_tanh"),
      TORCH_FN(QLinearTanhInt8::run));
}
1328
// _quantized::linear shares the plain (non-ReLU) QTensor linear kernel.
TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) {
  register_linear_params();
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::linear"),
      TORCH_FN(QLinearInt8<false>::run));
}
1333
// fp32-output linear variants registered for the plain CPU dispatch key.
TORCH_LIBRARY_IMPL(quantized, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::linear_with_input_q_dq_qweight_dq_output_fp32"),
      TORCH_FN(QLinearInt8FusedQDQ<false>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32"),
      TORCH_FN(QLinearInt8FusedQDQ<true>::run));
}
1338
// onednn::qlinear_pointwise overloads registered for the MkldnnCPU key.
TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise"), TORCH_FN(QLinearOnednn::run_pointwise));
  m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.tensor"), TORCH_FN(QLinearOnednn::run_pointwise_tensor));
  m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary"), TORCH_FN(QLinearOnednn::run_pointwise_binary));
  m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary_tensor"), TORCH_FN(QLinearOnednn::run_pointwise_binary_tensor));
}
1349
1350 } // namespace
1351 } // namespace native
1352 } // namespace at
1353