#include <ATen/native/quantized/AffineQuantizer.h>

#include <cfenv>
#include <limits>
#include <string>

#include <c10/util/irange.h>

namespace at::native {

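// Dispatch stubs for the quantize/dequantize kernels. Each stub is
// registered with a device-specific implementation elsewhere (e.g. CPU and
// CUDA kernels) and is selected at runtime by device type.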
DEFINE_DISPATCH(quantize_tensor_per_tensor_affine_stub);
DEFINE_DISPATCH(quantize_tensor_per_channel_affine_stub);
DEFINE_DISPATCH(quantize_tensor_per_channel_float_qparams_stub);
DEFINE_DISPATCH(dequantize_tensor_per_tensor_affine_stub);
DEFINE_DISPATCH(dequantize_tensor_per_channel_affine_stub);
DEFINE_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub);
DEFINE_DISPATCH(quantize_tensor_per_tensor_affine_sub_byte_stub);
DEFINE_DISPATCH(dequantize_tensor_per_tensor_affine_sub_byte_stub);

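// Input-validation helpers shared by the quantize/dequantize entry points
// below.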
namespace {

void checkRoundingMode(const std::string& fn_name) {
  // Disabling this warning message for now as it is printed incorrectly:
  // TORCH_WARN_ONCE takes only message arguments (no leading condition), so
  // the fegetround() check below was folded into the message. Needs a fix.

  /*  TORCH_WARN_ONCE(
        std::fegetround() != FE_TONEAREST,
        fn_name,
        " current rounding mode is not set to round-to-nearest-ties-to-even"
        " (FE_TONEAREST). This will cause accuracy issues in quantized models.");
  */
}

void checkFloatTensor(const std::string& fn_name, const Tensor& t) {
  TORCH_CHECK(
      t.scalar_type() == kFloat, fn_name, " expects a Float Tensor, got ",
      t.scalar_type());
}

void checkSameDevice(
    const std::string& fn_name,
    const Tensor& t1,
    const Tensor& t2) {
  TORCH_CHECK(
      t1.device() == t2.device(),
      fn_name,
      " expects the quantized and float tensors to be on the same device.");
}

template <typename T>
void checkQuantizedTensor(const std::string& fn_name, const Tensor& t) {
  TORCH_CHECK(t.is_quantized(), fn_name, " expects a quantized Tensor.");
  TORCH_CHECK(
      t.scalar_type() == caffe2::TypeMeta::Make<T>(),
      fn_name,
      " expects a ",
      caffe2::TypeMeta::Make<T>(),
      " Tensor, got ",
      t.scalar_type());
}

template <typename T>
void checkZeroPoint(const std::string& fn_name, int64_t zero_point) {
  TORCH_CHECK(
      zero_point <= std::numeric_limits<T>::max(),
      fn_name,
      " zero_point ",
      zero_point,
      " is above upper bound.");
  TORCH_CHECK(
      zero_point >= std::numeric_limits<T>::min(),
      fn_name,
      " zero_point ",
      zero_point,
      " is below lower bound.");
}

template <typename T>
void checkZeroPoints(const std::string& fn_name, const Tensor& zero_points) {
  // Per-channel zero points are stored as a kLong tensor; validate each one
  // against the range of the underlying quantized type T.
  auto zero_points_data = zero_points.data_ptr<int64_t>();
  for (const auto i : c10::irange(zero_points.numel())) {
    checkZeroPoint<T>(fn_name, zero_points_data[i]);
  }
}

void checkSameSize(
    const std::string& fn_name,
    const Tensor& qt,
    const Tensor& rt) {
  TORCH_CHECK(
      qt.sizes().equals(rt.sizes()),
      fn_name,
      " only works with Tensors with the same shape");
}

void checkPerChannelParamsSize(
    const Tensor& rtensor,
    int64_t axis,
    const Tensor& scales,
    const Tensor& zero_points) {
  int64_t channel = rtensor.size(axis);
  TORCH_CHECK(
      channel == int64_t(scales.numel()),
      "length of scales must equal the number of channels, expected ",
      channel, ", got ", scales.numel());
  TORCH_CHECK(
      channel == int64_t(zero_points.numel()),
      "length of zero_points must equal the number of channels, expected ",
      channel, ", got ", zero_points.numel());
}

} // anonymous namespace

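// Quantizes the float tensor `rtensor` into the preallocated quantized
// tensor `qtensor` using a single (scale, zero_point) pair, i.e.
// q = clamp(round(r / scale) + zero_point, qmin, qmax).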
Tensor& quantize_tensor_per_tensor_affine(
    const Tensor& rtensor,
    Tensor& qtensor,
    double scale,
    int64_t zero_point) {
  static constexpr auto fn_name = "quantize_tensor_per_tensor_affine";

  checkRoundingMode(fn_name);
  checkFloatTensor(fn_name, rtensor);
  checkSameDevice(fn_name, rtensor, qtensor);
  checkSameSize(fn_name, qtensor, rtensor);

  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(qtensor.scalar_type(), fn_name, [&]() {
    checkQuantizedTensor<scalar_t>(fn_name, qtensor);
    checkZeroPoint<underlying_t>(fn_name, zero_point);
  });

  // Temporary solution to pack the tensor if the dtype is torch.quint4x2 or
  // torch.quint2x4. This could move into the fbgemm::Quantize op.
  if (qtensor.scalar_type() == at::ScalarType::QUInt4x2 ||
      qtensor.scalar_type() == at::ScalarType::QUInt2x4) {
    quantize_tensor_per_tensor_affine_sub_byte_stub(
        rtensor.device().type(), rtensor, qtensor, scale, zero_point);
  } else {
    quantize_tensor_per_tensor_affine_stub(
        rtensor.device().type(), rtensor, qtensor, scale, zero_point);
  }
  return qtensor;
}
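// A minimal usage sketch (not part of this file; assumes the caller
// allocates the quantized output with matching qparams first):
//
//   at::Tensor r = at::rand({2, 3});
//   at::Tensor q = at::_empty_affine_quantized(
//       r.sizes(),
//       r.options().dtype(at::kQUInt8),
//       /*scale=*/0.1,
//       /*zero_point=*/10);
//   at::native::quantize_tensor_per_tensor_affine(
//       r, q, /*scale=*/0.1, /*zero_point=*/10);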
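// Per-channel variant: each slice along `axis` gets its own scale and
// (integer) zero point from the `scales` and `zero_points` tensors.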
Tensor& quantize_tensor_per_channel_affine(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    Tensor zero_points,
    int64_t axis) {
  static constexpr auto fn_name = "quantize_tensor_per_channel_affine";

  checkRoundingMode(fn_name);
  checkFloatTensor(fn_name, rtensor);
  checkSameDevice(fn_name, rtensor, qtensor);
  checkSameSize(fn_name, qtensor, rtensor);

  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() {
    checkQuantizedTensor<scalar_t>(fn_name, qtensor);
    if (qtensor.device().type() != c10::DeviceType::CUDA &&
        qtensor.device().type() != c10::DeviceType::PrivateUse1) {
      checkZeroPoints<underlying_t>(fn_name, zero_points);
    }  // for CUDA and PrivateUse1, this check occurs in the device function
  });

  TORCH_CHECK(
      0 <= axis && axis < rtensor.dim(),
      "Channel axis out of range in per channel affine quantization. Got: ",
      axis,
      ". Expected: [0, ",
      rtensor.dim(),
      ")");
  checkPerChannelParamsSize(rtensor, axis, scales, zero_points);

  quantize_tensor_per_channel_affine_stub(
      rtensor.device().type(), rtensor, qtensor, scales, zero_points, axis);
  return qtensor;
}

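// Per-channel variant with floating-point qparams (`scales` and
// `zero_points` are both float tensors); used e.g. for sub-byte quantized
// embeddings.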
Tensor& quantize_tensor_per_channel_float_qparams(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis) {
  static constexpr auto fn_name = "quantize_tensor_per_channel_float_qparams";

  checkRoundingMode(fn_name);
  checkFloatTensor(fn_name, rtensor);
  checkSameDevice(fn_name, rtensor, qtensor);
  checkSameSize(fn_name, qtensor, rtensor);

  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(qtensor.scalar_type(), fn_name, [&]() {
    checkQuantizedTensor<scalar_t>(fn_name, qtensor);
  });

  TORCH_CHECK(
      0 <= axis && axis < rtensor.dim(),
      "Channel axis out of range in per channel float qparams quantization. Got: ",
      axis,
      ". Expected: [0, ",
      rtensor.dim(),
      ")");
  checkPerChannelParamsSize(rtensor, axis, scales, zero_points);

  quantize_tensor_per_channel_float_qparams_stub(
      rtensor.device().type(), rtensor, qtensor, scales, zero_points, axis);
  return qtensor;
}

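// Dequantizes `qtensor` into the preallocated float tensor `rtensor`,
// i.e. r = (q - zero_point) * scale.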
Tensor& dequantize_tensor_per_tensor_affine(
    const Tensor& qtensor,
    Tensor& rtensor,
    double scale,
    int64_t zero_point) {
  static constexpr auto fn_name = "dequantize_tensor_per_tensor_affine";
  checkFloatTensor(fn_name, rtensor);
  checkSameDevice(fn_name, rtensor, qtensor);
  checkSameSize(fn_name, qtensor, rtensor);

  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(qtensor.scalar_type(), fn_name, [&]() {
    checkQuantizedTensor<scalar_t>(fn_name, qtensor);
    checkZeroPoint<underlying_t>(fn_name, zero_point);
  });

  // Sub-byte dtypes (torch.quint4x2, torch.quint2x4) are stored packed and
  // need the dedicated unpacking kernel.
  if (qtensor.scalar_type() == at::ScalarType::QUInt4x2 ||
      qtensor.scalar_type() == at::ScalarType::QUInt2x4) {
    dequantize_tensor_per_tensor_affine_sub_byte_stub(
        qtensor.device().type(), qtensor, rtensor, scale, zero_point);
  } else {
    dequantize_tensor_per_tensor_affine_stub(
        qtensor.device().type(), qtensor, rtensor, scale, zero_point);
  }
  return rtensor;
}
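// A minimal usage sketch (not part of this file; `q` is a quantized tensor
// produced as in the quantize example above):
//
//   at::Tensor r = at::empty(q.sizes(), q.options().dtype(at::kFloat));
//   at::native::dequantize_tensor_per_tensor_affine(
//       q, r, /*scale=*/0.1, /*zero_point=*/10);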
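// Per-channel dequantization: each slice along `axis` is dequantized with
// its own scale and zero point.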
Tensor& dequantize_tensor_per_channel_affine(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    Tensor zero_points,
    int64_t axis) {
  static constexpr auto fn_name = "dequantize_tensor_per_channel_affine";

  checkFloatTensor(fn_name, rtensor);
  checkSameDevice(fn_name, rtensor, qtensor);
  checkSameSize(fn_name, qtensor, rtensor);

  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() {
    checkQuantizedTensor<scalar_t>(fn_name, qtensor);
    if (qtensor.device().type() != c10::DeviceType::CUDA &&
        qtensor.device().type() != c10::DeviceType::PrivateUse1) {
      checkZeroPoints<underlying_t>(fn_name, zero_points);
    }  // for CUDA and PrivateUse1, this check occurs in the device function
  });

  TORCH_CHECK(
      0 <= axis && axis < qtensor.dim(),
      "Channel axis out of range in per channel affine dequantization. Got: ",
      axis,
      ". Expected: [0, ",
      qtensor.dim(),
      ")");
  checkPerChannelParamsSize(rtensor, axis, scales, zero_points);

  dequantize_tensor_per_channel_affine_stub(
      qtensor.device().type(), qtensor, rtensor, scales, zero_points, axis);
  return rtensor;
}

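// Per-channel dequantization with floating-point scales and zero points,
// mirroring quantize_tensor_per_channel_float_qparams above.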
Tensor& dequantize_tensor_per_channel_float_qparams(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis) {
  static constexpr auto fn_name = "dequantize_tensor_per_channel_float_qparams";

  checkFloatTensor(fn_name, rtensor);
  checkSameDevice(fn_name, rtensor, qtensor);
  checkSameSize(fn_name, qtensor, rtensor);

  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(qtensor.scalar_type(), fn_name, [&]() {
    checkQuantizedTensor<scalar_t>(fn_name, qtensor);
  });

  TORCH_CHECK(
      0 <= axis && axis < qtensor.dim(),
      "Channel axis out of range in per channel float qparams dequantization. Got: ",
      axis,
      ". Expected: [0, ",
      qtensor.dim(),
      ")");
  checkPerChannelParamsSize(rtensor, axis, scales, zero_points);

  dequantize_tensor_per_channel_float_qparams_stub(
      qtensor.device().type(), qtensor, rtensor, scales, zero_points, axis);
  return rtensor;
}

} // namespace at::native