#include <ATen/native/quantized/AffineQuantizerBase.h>
#include <c10/util/irange.h>
#include <climits>

#ifdef USE_FBGEMM
#include <fbgemm/QuantUtils.h>
#endif
#ifdef __ARM_NEON__
#include <arm_neon.h>
#endif

namespace at::native {

namespace {
template <typename T>
void checkZeroPoint(const std::string& fn_name, int64_t zero_point) {
  TORCH_CHECK(
      zero_point <= std::numeric_limits<T>::max(),
      fn_name,
      " zero_point ",
      zero_point,
      " is out of range.");
  TORCH_CHECK(
      zero_point >= std::numeric_limits<T>::min(),
      fn_name,
      " zero_point ",
      zero_point,
      " is out of range.");
}

} // anonymous namespace

#ifdef USE_FBGEMM
// Note: quantize_val is only explicitly used in tests outside of this file.
template <typename T>
T quantize_val(double scale, int64_t zero_point, float value) {
  // Internally, fbgemm::Quantize uses std::nearbyint.
  // std::nearbyint rounds to the nearest integer according to the current
  // rounding mode; the default mode rounds half-way cases to even on the
  // most popular processor architectures, such as x86 and ARM. This is
  // typically faster than alternatives like std::round, which rounds
  // half-way cases away from zero, and it can be consistent with SIMD
  // implementations, for example x86's _mm512_cvtps_epi32 or _mm512_round_ps
  // with the _MM_FROUND_CUR_DIRECTION option, which also follow the current
  // rounding mode.
  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  auto qvalue = fbgemm::Quantize<typename T::underlying, false /*LEGACY*/>(
      value,
      static_cast<int32_t>(zero_point),
      static_cast<float>(scale),
      /*result_precision=*/CHAR_BIT * sizeof(typename T::underlying));
  return static_cast<T>(qvalue);
}
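// Illustrative example (added note, not from the original source): under the
// default round-to-nearest-even mode, std::nearbyint(2.5f) == 2.0f and
// std::nearbyint(3.5f) == 4.0f, whereas std::round would give 3 and 4. So
// quantize_val<quint8>(/*scale=*/1.0, /*zero_point=*/0, 2.5f) stores 2.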

template <typename T, int precision>
void quantize_vec(
    double scale,
    int64_t zero_point,
    const float* src,
    T* dst,
    size_t count) {
  fbgemm::Quantize<typename T::underlying, false /*LEGACY*/>(
      src,
      (typename T::underlying*)dst,
      count,
      fbgemm::TensorQuantizationParams{
          (float)scale, (int32_t)zero_point, precision});
}

#if defined(__ARM_NEON__) || defined(__aarch64__)
// For use when compiling FBGEMM on aarch64 but still supporting x86
// intrinsics via simde.
template <typename T>
T quantize_val_arm(
    const float scale,
    const int32_t zero_point,
    const float value) {
  constexpr int32_t qmin = std::numeric_limits<T>::min();
  constexpr int32_t qmax = std::numeric_limits<T>::max();
  float inv_scale = 1.0f / scale;
  auto r = zero_point + static_cast<int32_t>(std::nearbyint(value * inv_scale));
  r = std::max(r, qmin);
  r = std::min(r, qmax);
  return static_cast<T>(r);
}

template uint8_t quantize_val_arm<uint8_t>(
    const float scale,
    const int32_t zero_point,
    const float value);
template int8_t quantize_val_arm<int8_t>(
    const float scale,
    const int32_t zero_point,
    const float value);
#endif

template <typename T>
inline float dequantize_val(double scale, int64_t zero_point, T value) {
  fbgemm::TensorQuantizationParams qparams{};
  qparams.scale = static_cast<float>(scale);
  qparams.zero_point = static_cast<int32_t>(zero_point);
  return fbgemm::Dequantize<typename T::underlying>(value.val_, qparams);
}
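// Illustrative example (added note): dequantization inverts the affine map,
// computing Xf = (Xq - zero_point) * scale. E.g. with scale = 0.1 and
// zero_point = 10, a stored quint8 value of 35 dequantizes to
// (35 - 10) * 0.1f = 2.5f.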
#else // USE_FBGEMM

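// On older Android toolchains (no __NDK_MAJOR__ defined), std::nearbyint may
// be unavailable, so the global ::nearbyintf / ::nearbyint are used instead.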
#if defined(__ANDROID__) && !defined(__NDK_MAJOR__)
template <class T>
inline float Round(const float x) {
  return ::nearbyintf(x);
}
inline double Round(const double x) {
  return ::nearbyint(x);
}
#else
template <class T>
inline T Round(const T x) {
  return std::nearbyint(x);
}
#endif

template <typename T>
T quantize_val(double scale, int64_t zero_point, float value) {
  // std::nearbyint rounds to the nearest integer according to the current
  // rounding mode; the default mode rounds half-way cases to even on the
  // most popular processor architectures, such as x86 and ARM. This is
  // typically faster than alternatives like std::round, which rounds
  // half-way cases away from zero, and it can be consistent with SIMD
  // implementations, for example x86's _mm512_cvtps_epi32 or _mm512_round_ps
  // with the _MM_FROUND_CUR_DIRECTION option, which also follow the current
  // rounding mode.
  int64_t qvalue;
  constexpr int64_t qmin = std::numeric_limits<typename T::underlying>::min();
  constexpr int64_t qmax = std::numeric_limits<typename T::underlying>::max();
  float inv_scale = 1.0f / static_cast<float>(scale);
  qvalue = static_cast<int64_t>(zero_point + Round(value * inv_scale));
  qvalue = std::max<int64_t>(qvalue, qmin);
  qvalue = std::min<int64_t>(qvalue, qmax);
  return static_cast<T>(qvalue);
}
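// Illustrative example (added note): out-of-range values saturate at the type
// bounds, e.g. quantize_val<qint8>(/*scale=*/0.1, /*zero_point=*/0, 100.0f)
// first computes 0 + Round(100.0f * 10.0f) = 1000, which is then clamped to
// qmax = 127.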

template <typename T>
T quantize_val_arm(
    const float scale,
    const int32_t zero_point,
    const float value) {
  constexpr int32_t qmin = std::numeric_limits<T>::min();
  constexpr int32_t qmax = std::numeric_limits<T>::max();
  float inv_scale = 1.0f / scale;
#ifndef _MSC_VER
  auto r = static_cast<int32_t>(Round(value * inv_scale));
  // __builtin_add_overflow() returns true in case of overflow.
  if (__builtin_add_overflow(zero_point, r, &r)) {
    // zero_point must be a non-negative value between qmin and qmax,
    // so only positive overflow can happen here.
    r = qmax;
  }
#else
  auto r = zero_point + static_cast<int32_t>(Round(value * inv_scale));
#endif
  r = std::max(r, qmin);
  r = std::min(r, qmax);
  return static_cast<T>(r);
}
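// Illustrative example (added note): with T = quint8, zero_point = 255, and a
// value large enough that Round(value * inv_scale) lands near INT32_MAX
// (e.g. 2147483520), the int32 addition would wrap; __builtin_add_overflow
// detects this and the result is pinned to qmax instead.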

template <typename T, int precision>
void quantize_vec(
    double scale,
    int64_t zero_point,
    const float* src,
    T* dst,
    size_t count) {
  checkZeroPoint<typename T::underlying>("quantize_vec", zero_point);
  for (const auto i : c10::irange(count)) {
    dst[i] = quantize_val<T>(scale, zero_point, src[i]);
  }
}

template uint8_t quantize_val_arm<uint8_t>(
    const float scale,
    const int32_t zero_point,
    const float value);
template int8_t quantize_val_arm<int8_t>(
    const float scale,
    const int32_t zero_point,
    const float value);

template <typename T>
TORCH_API float dequantize_val(double scale, int64_t zero_point, T value) {
  return static_cast<float>(scale) *
      (value.val_ - static_cast<int32_t>(zero_point));
}
#endif // USE_FBGEMM

/*
 * Quantize a value based on the following equation:
 *   Xq = Round(Xf * inv_scale + zero_point)
 * where zero_point is in float.
 *
 * Note: For the case of embedding quantization we will set zero_point
 * to (-Xmin/scale), where Xmin is the min value in the input tensor row.
 */
int quantize_val_float_qparams(float scale, float zero_point, float value, int qmin, int qmax) {
  float inv_scale = scale == 0 ? 1.0f : 1.0f / scale;
  auto qvalue = static_cast<int>(lrintf(value * inv_scale + zero_point));
  qvalue = std::max(qmin, std::min(qvalue, qmax));
  return qvalue;
}
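// Illustrative example of the embedding case described above (added note):
// for a row with Xmin = -1.0f and scale = 0.05f, zero_point = -Xmin / scale
// = 20.0f, so the row minimum maps to lrintf(-1.0f * 20.0f + 20.0f) = 0,
// the bottom of the quantized range.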

template <typename SRC_T, typename DST_T>
DST_T requantize_val(
    double src_scale,
    int64_t src_zero_point,
    double dst_scale,
    int64_t dst_zero_point,
    SRC_T src) {
  const auto dq = dequantize_val<SRC_T>(src_scale, src_zero_point, src);
  return quantize_val<DST_T>(dst_scale, dst_zero_point, dq);
}
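// Illustrative note (added): requantize_val composes the two affine maps
// above, so it is equivalent to
//   Xq_dst = Round((src_scale / dst_scale) * (Xq_src - src_zero_point))
//            + dst_zero_point,
// up to float rounding of the intermediate dequantized value, followed by
// clamping to the destination type's range.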

template <typename DST_T>
DST_T requantize_from_int(double multiplier, int64_t zero_point, int64_t src) {
  int64_t quantize_down =
      zero_point + lrintf(static_cast<float>(static_cast<double>(src) * multiplier));
  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  int32_t min = std::numeric_limits<typename DST_T::underlying>::min();
  int32_t max = std::numeric_limits<typename DST_T::underlying>::max();
  return static_cast<DST_T>(
      std::min<int64_t>(std::max<int64_t>(quantize_down, min), max));
}
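// Illustrative example (added note): requantize_from_int maps an integer
// accumulator (e.g. the output of an integer matmul) into the destination
// type. With multiplier = 0.01, zero_point = 0, and src = 12345,
// quantize_down = lrintf(123.45f) = 123, which already lies within qint8's
// [-128, 127] range.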

template TORCH_API qint8
quantize_val<qint8>(double scale, int64_t zero_point, float value);
template TORCH_API quint8
quantize_val<quint8>(double scale, int64_t zero_point, float value);
template TORCH_API qint32
quantize_val<qint32>(double scale, int64_t zero_point, float value);
template TORCH_API void quantize_vec<c10::qint8>(
    double scale,
    int64_t zero_point,
    const float* src,
    c10::qint8* dst,
    size_t count);
template TORCH_API void quantize_vec<c10::quint8>(
    double scale,
    int64_t zero_point,
    const float* src,
    c10::quint8* dst,
    size_t count);
template TORCH_API void quantize_vec<c10::qint32, 32>(
    double scale,
    int64_t zero_point,
    const float* src,
    c10::qint32* dst,
    size_t count);

template TORCH_API float dequantize_val<qint8>(
    double scale,
    int64_t zero_point,
    qint8 value);
template TORCH_API float dequantize_val<quint8>(
    double scale,
    int64_t zero_point,
    quint8 value);
template TORCH_API float dequantize_val<qint32>(
    double scale,
    int64_t zero_point,
    qint32 value);

template TORCH_API qint8
requantize_val<qint8, qint8>(double, int64_t, double, int64_t, qint8);
template TORCH_API quint8
requantize_val<qint8, quint8>(double, int64_t, double, int64_t, qint8);
template TORCH_API qint32
requantize_val<qint8, qint32>(double, int64_t, double, int64_t, qint8);
template TORCH_API qint8
requantize_val<quint8, qint8>(double, int64_t, double, int64_t, quint8);
template TORCH_API quint8
requantize_val<quint8, quint8>(double, int64_t, double, int64_t, quint8);
template TORCH_API qint32
requantize_val<quint8, qint32>(double, int64_t, double, int64_t, quint8);
template TORCH_API qint8
requantize_val<qint32, qint8>(double, int64_t, double, int64_t, qint32);
template TORCH_API quint8
requantize_val<qint32, quint8>(double, int64_t, double, int64_t, qint32);
template TORCH_API qint32
requantize_val<qint32, qint32>(double, int64_t, double, int64_t, qint32);

template TORCH_API qint8 requantize_from_int<qint8>(double, int64_t, int64_t);
template TORCH_API quint8
requantize_from_int<quint8>(double, int64_t, int64_t);
template TORCH_API qint32
requantize_from_int<qint32>(double, int64_t, int64_t);

} // namespace at::native