// aten/src/ATen/native/quantized/cpu/LinearUnpackImpl.cpp
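//
// unpack() implementations for the packed linear weight classes used by the
// quantized linear operators. Each backend (FBGEMM, QNNPACK, oneDNN) stores
// the weight in a packed, backend-specific layout; unpack() recovers the
// weight in its original (out_features, in_features) layout together with the
// optional bias. These methods back the quantized::linear_unpack operator,
// which is registered elsewhere; from Python the packed weight can typically
// be unpacked via torch.ops.quantized.linear_unpack(packed_weight).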
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Context.h>
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/quantized/cpu/OnednnUtils.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <torch/custom_class.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/_empty_per_channel_affine_quantized.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/from_blob.h>
#endif

int register_linear_params();

#ifdef USE_FBGEMM
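// FBGEMM path: reconstruct the (N, K) qint8 weight tensor from fbgemm's
// packed B matrix, using either the per-tensor or the per-channel
// quantization parameters stored alongside the packed weight.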
std::tuple<at::Tensor, std::optional<at::Tensor>> PackedLinearWeight::unpack() {
  auto packB = w.get();

  int64_t N = static_cast<int64_t>(packB->numCols());
  int64_t K = static_cast<int64_t>(packB->numRows());

  at::Tensor weight_origin;
  if (q_scheme == c10::kPerTensorAffine) {
    weight_origin = at::_empty_affine_quantized(
        {N, K}, at::device(c10::kCPU).dtype(c10::kQInt8), w_scale[0], w_zp[0]);
  } else if (q_scheme == c10::kPerChannelAffine) {
    auto scales = at::from_blob(
        w_scale.data(), w_scale.size(), device(c10::kCPU).dtype(c10::kFloat));
    auto zero_points = at::from_blob(
        w_zp.data(), w_zp.size(), device(c10::kCPU).dtype(c10::kInt));

    weight_origin = at::_empty_per_channel_affine_quantized(
        {N, K},
        scales.toType(c10::kDouble),
        zero_points.toType(c10::kLong),
        0, // The output channel axis is 0
        device(c10::kCPU).dtype(c10::kQInt8));
  }

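  // The destination tensor owns qint8 storage; reinterpret it as int8_t so
  // fbgemm can write the unpacked values directly into it.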
  int8_t* weight_ptr_int8 =
      reinterpret_cast<int8_t*>(weight_origin.data_ptr<c10::qint8>());

  // packB->printPackedMatrix("packedB inside fbgemm_unpack
  // (QLinearUnpackWeightInt8): ");
  packB->unpack(weight_ptr_int8);

  return std::tuple<at::Tensor, std::optional<at::Tensor>>(
      weight_origin, bias_);
}
#endif // USE_FBGEMM

#ifdef USE_PYTORCH_QNNPACK
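// QNNPACK path: if the original weight tensor is still cached, return it
// directly; otherwise rebuild it from the packed representation, dropping the
// channel padding and shifting the uint8 zero points and values back into the
// qint8 range.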
std::tuple<at::Tensor, std::optional<at::Tensor>> PackedLinearWeightsQnnp::
    unpack() {
  if (orig_weight.defined()) {
    return std::tuple<at::Tensor, std::optional<at::Tensor>>(
        orig_weight, bias_);
  } else {
    // Unpacking requires reverting the *make_zero_points_and_scales_tensor*
    // function in QnnpackUtils.h. Please refer to the following for the
    // detailed mechanism:
    // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h#L469
    // w_scales and w_zero_points differ from the original scales & zero
    // points due to padding, casting, etc.
    at::Tensor weight_origin;

    float* weight_scales_data = w_scales.data_ptr<float>();
    if (q_scheme == c10::kPerTensorAffine) {
      weight_origin = at::_empty_affine_quantized(
          weight_sizes,
          at::device(c10::kCPU).dtype(c10::kQInt8),
          static_cast<double>(weight_scales_data[0]),
          (int64_t)w_zero_points[0] - 128);
    } else if (q_scheme == c10::kPerChannelAffine) {
      auto scales = at::from_blob(
          weight_scales_data,
          w_scales.sizes()[0] - kPaddingChannels,
          device(c10::kCPU).dtype(c10::kFloat));

      at::Tensor zero_points = at::empty(
          w_zero_points.size() - kPaddingChannels, at::device(c10::kCPU).dtype(c10::kLong));
      for (const auto i : c10::irange(zero_points.numel())) {
        zero_points[i] = ((int64_t)w_zero_points[i] - 128);
      }
      weight_origin = at::_empty_per_channel_affine_quantized(
                          weight_sizes,
                          scales,
                          zero_points.toType(c10::kLong),
                          0, // The output channel axis is 0
                          device(c10::kCPU).dtype(c10::kQInt8))
                          .contiguous();
    } else {
      TORCH_INTERNAL_ASSERT(false, "Unsupported quantization scheme.");
    }
    int8_t* weight_ptr_int8 =
        reinterpret_cast<int8_t*>(weight_origin.data_ptr<c10::qint8>());
    w->unpackWeights(w_zero_points.data(), weight_ptr_int8);
    // See the following for why 128 is subtracted from each value:
    // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp#L319
    auto wt_numel = weight_origin.numel();
    for (const auto i : c10::irange(wt_numel)) {
      weight_ptr_int8[i] = (int8_t)(weight_ptr_int8[i] - 128);
    }

    return std::tuple<at::Tensor, std::optional<at::Tensor>>(
        weight_origin, bias_);
  }
}
#endif // USE_PYTORCH_QNNPACK

#ifdef USE_FBGEMM
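// FBGEMM fp16 path: unpack the packed half-precision weight and return it as
// a float tensor in the original (out_features, in_features) layout.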
std::tuple<at::Tensor, std::optional<at::Tensor>> PackedLinearWeightFp16::
    unpack() {
  auto& packed_weight_ptr = w;

  auto nrows = packed_weight_ptr->numRows();
  auto ncols = packed_weight_ptr->numCols();

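  // The destination is allocated as (ncols, nrows) and filled via a
  // transposed unpack so that it matches the original weight layout.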
  at::Tensor unpacked_weight =
      at::empty({ncols, nrows}, at::kHalf, c10::MemoryFormat::Contiguous);
  packed_weight_ptr->unpack(
      static_cast<fbgemm::float16*>(unpacked_weight.data_ptr()),
      fbgemm::matrix_op_t::Transpose);

  return std::make_tuple(unpacked_weight.to(at::kFloat), bias_);
}
#endif // USE_FBGEMM

#if AT_MKLDNN_ENABLED()
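// oneDNN path: the packed weight object keeps the original weight and bias
// around, so unpack() simply returns them.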
std::tuple<at::Tensor, std::optional<at::Tensor>> PackedLinearWeightsOnednn::unpack() {
  return std::tuple<at::Tensor, std::optional<at::Tensor>>(
      orig_weight_, orig_bias_);
}
#endif // #if AT_MKLDNN_ENABLED()