#include <ATen/native/quantized/AffineQuantizer.h>
#include <c10/util/irange.h>

#include <cfenv>
#include <limits>
#include <string>

namespace at::native {

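// Each stub below is a dispatch-table entry point created by
// DEFINE_DISPATCH; per-backend kernels (CPU, CUDA, ...) register their
// implementations into these tables via REGISTER_DISPATCH.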
DEFINE_DISPATCH(quantize_tensor_per_tensor_affine_stub);
DEFINE_DISPATCH(quantize_tensor_per_channel_affine_stub);
DEFINE_DISPATCH(quantize_tensor_per_channel_float_qparams_stub);
DEFINE_DISPATCH(dequantize_tensor_per_tensor_affine_stub);
DEFINE_DISPATCH(dequantize_tensor_per_channel_affine_stub);
DEFINE_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub);
DEFINE_DISPATCH(quantize_tensor_per_tensor_affine_sub_byte_stub);
DEFINE_DISPATCH(dequantize_tensor_per_tensor_affine_sub_byte_stub);

namespace {

void checkRoundingMode(const std::string& fn_name) {
  // Disabled for now: the warning message below is printed incorrectly.
  // Needs to be fixed before it can be re-enabled.

  /* TORCH_WARN_ONCE(
      std::fegetround() != FE_TONEAREST,
      fn_name,
      " current rounding mode is not set to round-to-nearest-ties-to-even"
      " (FE_TONEAREST). This will cause accuracy issues in quantized models.");
  */
  return;
}

void checkFloatTensor(const std::string& fn_name, const Tensor& t) {
  TORCH_CHECK(
      t.scalar_type() == kFloat, fn_name, " expects a Float Tensor, got ",
      t.scalar_type());
}

void checkSameDevice(
    const std::string& fn_name,
    const Tensor& t1,
    const Tensor& t2) {
  TORCH_CHECK(
      t1.device() == t2.device(),
      fn_name,
      " expects the quantized and float tensors to be on the same device.");
}

template <typename T>
void checkQuantizedTensor(const std::string& fn_name, const Tensor& t) {
  TORCH_CHECK(t.is_quantized(), fn_name, " expects a quantized Tensor.");
  TORCH_CHECK(
      t.scalar_type() == caffe2::TypeMeta::Make<T>(),
      fn_name,
      " expects a ",
      caffe2::TypeMeta::Make<T>(),
      " Tensor, got ",
      t.scalar_type());
}

template <typename T>
void checkZeroPoint(const std::string& fn_name, int64_t zero_point) {
  TORCH_CHECK(
      zero_point <= std::numeric_limits<T>::max(),
      fn_name,
      " zero_point ",
      zero_point,
      " is above upper bound.");
  TORCH_CHECK(
      zero_point >= std::numeric_limits<T>::min(),
      fn_name,
      " zero_point ",
      zero_point,
      " is below lower bound.");
}
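// For example, with T = uint8_t (the underlying type of quint8) the valid
// zero_point range is [0, 255]; with T = int8_t (qint8) it is [-128, 127].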

template <typename T>
void checkZeroPoints(const std::string& fn_name, const Tensor& zero_points) {
  // zero_points is expected to hold int64 values; each entry must fit into
  // the range of the underlying quantized type T.
  auto zero_points_data = zero_points.data_ptr<int64_t>();
  for (const auto i : c10::irange(zero_points.numel())) {
    checkZeroPoint<T>(fn_name, zero_points_data[i]);
  }
}

void checkSameSize(
    const std::string& fn_name,
    const Tensor& qt,
    const Tensor& rt) {
  TORCH_CHECK(
      qt.sizes().equals(rt.sizes()),
      fn_name,
      " only works with Tensors of the same shape");
}

void checkPerChannelParamsSize(
    const Tensor& rtensor,
    int64_t axis,
    const Tensor& scales,
    const Tensor& zero_points) {
  int64_t channel = rtensor.size(axis);
  TORCH_CHECK(
      channel == int64_t(scales.numel()),
      "length of scales must equal the number of channels, expected ",
      channel, ", got ", scales.numel());
  TORCH_CHECK(
      channel == int64_t(zero_points.numel()),
      "length of zero_points must equal the number of channels, expected ",
      channel, ", got ", zero_points.numel());
}
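// For example, quantizing a {N, C, H, W} tensor along axis 1 requires
// scales and zero_points to each hold exactly C elements.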

} // anonymous namespace

Tensor& quantize_tensor_per_tensor_affine(
    const Tensor& rtensor,
    Tensor& qtensor,
    double scale,
    int64_t zero_point) {
  static constexpr auto fn_name = "quantize_tensor_per_tensor_affine";

  checkRoundingMode(fn_name);
  checkFloatTensor(fn_name, rtensor);
  checkSameDevice(fn_name, rtensor, qtensor);
  checkSameSize(fn_name, qtensor, rtensor);

  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(qtensor.scalar_type(), fn_name, [&]() {
    checkQuantizedTensor<scalar_t>(fn_name, qtensor);
    checkZeroPoint<underlying_t>(fn_name, zero_point);
  });

  // Temporary solution to pack the tensor for the sub-byte dtypes
  // (torch.quint4x2, torch.quint2x4). This packing could move into the
  // fbgemm::Quantize op.
  if (qtensor.scalar_type() == at::ScalarType::QUInt4x2 ||
      qtensor.scalar_type() == at::ScalarType::QUInt2x4) {
    quantize_tensor_per_tensor_affine_sub_byte_stub(
        rtensor.device().type(), rtensor, qtensor, scale, zero_point);
  } else {
    quantize_tensor_per_tensor_affine_stub(
        rtensor.device().type(), rtensor, qtensor, scale, zero_point);
  }
  return qtensor;
}
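// Per element, the kernels compute the standard affine mapping
//   q = clamp(round(r / scale) + zero_point, qmin, qmax).
// A minimal caller sketch (hypothetical; the destination must be
// pre-allocated with a matching shape):
//
//   Tensor r = at::rand({2, 3});
//   Tensor q = at::_empty_affine_quantized(
//       r.sizes(), at::device(at::kCPU).dtype(at::kQUInt8),
//       /*scale=*/0.1, /*zero_point=*/10);
//   quantize_tensor_per_tensor_affine(r, q, /*scale=*/0.1, /*zero_point=*/10);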

Tensor& quantize_tensor_per_channel_affine(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    Tensor zero_points,
    int64_t axis) {
  static constexpr auto fn_name = "quantize_tensor_per_channel_affine";

  checkRoundingMode(fn_name);
  checkFloatTensor(fn_name, rtensor);
  checkSameDevice(fn_name, rtensor, qtensor);
  checkSameSize(fn_name, qtensor, rtensor);

  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() {
    checkQuantizedTensor<scalar_t>(fn_name, qtensor);
    if (qtensor.device().type() != c10::DeviceType::CUDA &&
        qtensor.device().type() != c10::DeviceType::PrivateUse1) {
      checkZeroPoints<underlying_t>(fn_name, zero_points);
    } // For CUDA and PrivateUse1, this check happens in the device kernel.
  });

  TORCH_CHECK(
      0 <= axis && axis < rtensor.dim(),
      "Channel axis out of range in per channel affine quantization. Got: ",
      axis,
      ". Expected: [0, ",
      rtensor.dim(),
      ")");
  checkPerChannelParamsSize(rtensor, axis, scales, zero_points);

  quantize_tensor_per_channel_affine_stub(
      rtensor.device().type(), rtensor, qtensor, scales, zero_points, axis);
  return qtensor;
}
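// A per-channel caller sketch (hypothetical): quantizing a {2, 3} tensor
// along axis 1 requires one (scale, zero_point) pair per channel, i.e.
// three of each:
//
//   Tensor r = at::rand({2, 3});
//   Tensor scales = at::tensor({0.1, 0.2, 0.3}, at::kDouble);
//   Tensor zero_points = at::tensor({0, 5, 10}, at::kLong);
//   Tensor q = at::_empty_per_channel_affine_quantized(
//       r.sizes(), scales, zero_points, /*axis=*/1,
//       at::device(at::kCPU).dtype(at::kQUInt8));
//   quantize_tensor_per_channel_affine(r, q, scales, zero_points, /*axis=*/1);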

Tensor& quantize_tensor_per_channel_float_qparams(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis) {
  static constexpr auto fn_name =
      "quantize_tensor_per_channel_float_qparams";

  checkRoundingMode(fn_name);
  checkFloatTensor(fn_name, rtensor);
  checkSameDevice(fn_name, rtensor, qtensor);
  checkSameSize(fn_name, qtensor, rtensor);

  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(qtensor.scalar_type(), fn_name, [&]() {
    checkQuantizedTensor<scalar_t>(fn_name, qtensor);
  });

  TORCH_CHECK(
      0 <= axis && axis < rtensor.dim(),
      "Channel axis out of range in per channel float qparams quantization. Got: ",
      axis,
      ". Expected: [0, ",
      rtensor.dim(),
      ")");
  checkPerChannelParamsSize(rtensor, axis, scales, zero_points);

  quantize_tensor_per_channel_float_qparams_stub(
      rtensor.device().type(), rtensor, qtensor, scales, zero_points, axis);
  return qtensor;
}
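// Unlike the affine variant above, the float-qparams scheme carries a
// floating-point zero_point per channel (which is why no integral
// zero_point range check is done here); it is used, e.g., for sub-byte
// dtypes such as torch.quint4x2.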

Tensor& dequantize_tensor_per_tensor_affine(
    const Tensor& qtensor,
    Tensor& rtensor,
    double scale,
    int64_t zero_point) {
  static constexpr auto fn_name = "dequantize_tensor_per_tensor_affine";
  checkFloatTensor(fn_name, rtensor);
  checkSameDevice(fn_name, rtensor, qtensor);
  checkSameSize(fn_name, qtensor, rtensor);

  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(qtensor.scalar_type(), fn_name, [&]() {
    checkQuantizedTensor<scalar_t>(fn_name, qtensor);
    checkZeroPoint<underlying_t>(fn_name, zero_point);
  });

  if (qtensor.scalar_type() == at::ScalarType::QUInt4x2 ||
      qtensor.scalar_type() == at::ScalarType::QUInt2x4) {
    dequantize_tensor_per_tensor_affine_sub_byte_stub(
        qtensor.device().type(), qtensor, rtensor, scale, zero_point);
  } else {
    dequantize_tensor_per_tensor_affine_stub(
        qtensor.device().type(), qtensor, rtensor, scale, zero_point);
  }
  return rtensor;
}
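// Dequantization is the inverse affine mapping, applied per element:
//   r = (q - zero_point) * scale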

Tensor& dequantize_tensor_per_channel_affine(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    Tensor zero_points,
    int64_t axis) {
  static constexpr auto fn_name = "dequantize_tensor_per_channel_affine";

  checkFloatTensor(fn_name, rtensor);
  checkSameDevice(fn_name, rtensor, qtensor);
  checkSameSize(fn_name, qtensor, rtensor);

  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() {
    checkQuantizedTensor<scalar_t>(fn_name, qtensor);
    if (qtensor.device().type() != c10::DeviceType::CUDA &&
        qtensor.device().type() != c10::DeviceType::PrivateUse1) {
      checkZeroPoints<underlying_t>(fn_name, zero_points);
    } // For CUDA and PrivateUse1, this check happens in the device kernel.
  });

  TORCH_CHECK(
      0 <= axis && axis < qtensor.dim(),
      "Channel axis out of range in per channel affine dequantization. Got: ",
      axis,
      ". Expected: [0, ",
      qtensor.dim(),
      ")");
  checkPerChannelParamsSize(rtensor, axis, scales, zero_points);

  dequantize_tensor_per_channel_affine_stub(
      qtensor.device().type(), qtensor, rtensor, scales, zero_points, axis);
  return rtensor;
}
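// The per-channel inverse mapping selects the (scale, zero_point) pair by
// the element's index c along `axis`:
//   r = (q - zero_points[c]) * scales[c]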

Tensor& dequantize_tensor_per_channel_float_qparams(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis) {
  static constexpr auto fn_name =
      "dequantize_tensor_per_channel_float_qparams";

  checkFloatTensor(fn_name, rtensor);
  checkSameDevice(fn_name, rtensor, qtensor);
  checkSameSize(fn_name, qtensor, rtensor);

  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(qtensor.scalar_type(), fn_name, [&]() {
    checkQuantizedTensor<scalar_t>(fn_name, qtensor);
  });

  TORCH_CHECK(
      0 <= axis && axis < qtensor.dim(),
      "Channel axis out of range in per channel float qparams dequantization. Got: ",
      axis,
      ". Expected: [0, ",
      qtensor.dim(),
      ")");
  checkPerChannelParamsSize(rtensor, axis, scales, zero_points);

  dequantize_tensor_per_channel_float_qparams_stub(
      qtensor.device().type(), qtensor, rtensor, scales, zero_points, axis);
  return rtensor;
}

} // namespace at::native