1 #include <ATen/native/quantized/AffineQuantizerBase.h>
2 #include <c10/util/irange.h>
3 #include <climits>
4
5 #ifdef USE_FBGEMM
6 #include <fbgemm/QuantUtils.h>
7 #endif
8 #ifdef __ARM_NEON__
9 #include <arm_neon.h>
10 #endif
11
12
13 namespace at::native {
14
15 namespace {
16
17 template <typename T>
checkZeroPoint(const std::string & fn_name,int64_t zero_point)18 void checkZeroPoint(const std::string& fn_name, int64_t zero_point) {
19 TORCH_CHECK(
20 zero_point <= std::numeric_limits<T>::max(),
21 fn_name,
22 " zero_point ",
23 zero_point,
24 " is out of range.");
25 TORCH_CHECK(
26 zero_point >= std::numeric_limits<T>::min(),
27 fn_name,
28 " zero_point ",
29 zero_point,
30 " is out of range.");
31 }
32
33 } // anonymous namespace
34
35 #ifdef USE_FBGEMM
36 // Note: quantize_val is only explicitly used in test outside of this file
template <typename T>
T quantize_val(double scale, int64_t zero_point, float value) {
  // Quantizes one float to the quantized wrapper type T (e.g. c10::qint8)
  // by delegating to fbgemm::Quantize at the full bit-width of T::underlying.
  // Internally, fbgemm::Quantize uses std::nearbyint.
  // std::nearbyint results in nearest integer value according to the current
  // rounding mode and the default rounding mode is rounds to even in half-way
  // cases in most popular processor architectures like x86 and ARM. This is
  // typically faster than alternatives like std::round that rounds half-way
  // cases away from zero, and can be consistent with SIMD implementations for
  // example in x86 using _mm512_cvtps_epi32 or mm512_round_ps with
  // _MM_FROUND_CUR_DIRECTION option that also follow the current rounding mode.
  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  auto qvalue = fbgemm::Quantize<typename T::underlying, false /*LEGACY*/>(
      value,
      static_cast<int32_t>(zero_point),
      static_cast<float>(scale),
      /*result_precision=*/CHAR_BIT * sizeof(typename T::underlying));
  return static_cast<T>(qvalue);
}
55
template <typename T, int precision>
void quantize_vec(
    double scale,
    int64_t zero_point,
    const float* src,
    T* dst,
    size_t count) {
  // Bulk-quantizes `count` floats from `src` into `dst` through fbgemm's
  // vectorized Quantize path. `precision` is the output bit-width of the
  // quantized result (e.g. 8 for qint8/quint8, 32 for qint32).
  fbgemm::Quantize<typename T::underlying, false /*LEGACY*/>(
      src,
      (typename T::underlying*)dst,
      count,
      fbgemm::TensorQuantizationParams{
          (float)scale, (int32_t)zero_point, precision});
}
70
71 #if defined(__ARM_NEON__) || defined(__aarch64__)
72 // For use when compiling FBGEMM on aarch64 but still supporting x86
73 // intrinsics via simde
template <typename T>
T quantize_val_arm(
    const float scale,
    const int32_t zero_point,
    const float value) {
  // Quantizes `value` with round-to-nearest (std::nearbyint under the
  // current FP rounding mode, half-to-even by default) and clamps the
  // result to T's representable range [qmin, qmax].
  constexpr int32_t qmin = std::numeric_limits<T>::min();
  constexpr int32_t qmax = std::numeric_limits<T>::max();
  float inv_scale = 1.0f / scale;
#ifndef _MSC_VER
  auto r = static_cast<int32_t>(std::nearbyint(value * inv_scale));
  // Guard the zero_point shift against int32 overflow, consistent with the
  // non-FBGEMM quantize_val_arm in this file; __builtin_add_overflow()
  // returns true in case of overflow.
  if (__builtin_add_overflow(zero_point, r, &r)) {
    // zero_point must be a non-negative value between qmin and qmax,
    // i.e. only positive overflow can happen here.
    r = qmax;
  }
#else
  // MSVC has no __builtin_add_overflow; rely on the clamp below for
  // in-range results.
  auto r = zero_point + static_cast<int32_t>(std::nearbyint(value * inv_scale));
#endif
  r = std::max(r, qmin);
  r = std::min(r, qmax);
  return static_cast<T>(r);
}

template uint8_t quantize_val_arm<uint8_t>(
    const float scale,
    const int32_t zero_point,
    const float value);
template int8_t quantize_val_arm<int8_t>(
    const float scale,
    const int32_t zero_point,
    const float value);
96 #endif
97
template <typename T>
inline float dequantize_val(double scale, int64_t zero_point, T value) {
  // Recovers the approximate float value scale * (q - zero_point) by
  // delegating to fbgemm::Dequantize on the raw underlying integer
  // `value.val_`.
  fbgemm::TensorQuantizationParams qparams{};
  qparams.scale = static_cast<float>(scale);
  qparams.zero_point = static_cast<int32_t>(zero_point);
  return fbgemm::Dequantize<typename T::underlying>(value.val_, qparams);
}
105 #else // USE_FBGEMM
106
#if defined(__ANDROID__) && !defined(__NDK_MAJOR__)
// Older Android toolchains (no NDK major version defined) lack the
// std::nearbyint overload set, so fall back to the C library functions in
// the global namespace.
template <class T>
inline float Round(const float x) {
  return ::nearbyintf(x);
}
inline double Round(const double x) {
  return ::nearbyint(x);
}
#else
// Round to the nearest integer value honoring the current FP rounding mode
// (round-half-to-even by default).
template <class T>
inline T Round(const T x) {
  return std::nearbyint(x);
}
#endif
121
122 template <typename T>
quantize_val(double scale,int64_t zero_point,float value)123 T quantize_val(double scale, int64_t zero_point, float value) {
124 // std::nearbyint results in nearest integer value according to the current
125 // rounding mode and the default rounding mode is rounds to even in half-way
126 // cases in most popular processor architectures like x86 and ARM. This is
127 // typically faster than an alternatives like std::round that rounds half-way
128 // cases away from zero, and can be consistent with SIMD implementations for
129 // example in x86 using _mm512_cvtps_epi32 or mm512_round_ps with
130 // _MM_FROUND_CUR_DIRECTION option that also follow the current rounding mode.
131 int64_t qvalue;
132 constexpr int64_t qmin = std::numeric_limits<typename T::underlying>::min();
133 constexpr int64_t qmax = std::numeric_limits<typename T::underlying>::max();
134 float inv_scale = 1.0f / static_cast<float>(scale);
135 qvalue = static_cast<int64_t>(zero_point + Round(value * inv_scale));
136 qvalue = std::max<int64_t>(qvalue, qmin);
137 qvalue = std::min<int64_t>(qvalue, qmax);
138 return static_cast<T>(qvalue);
139 }
140
template <typename T>
T quantize_val_arm(
    const float scale,
    const int32_t zero_point,
    const float value) {
  // Quantizes `value` to the integral type T (uint8_t/int8_t), rounding via
  // Round() (current FP rounding mode) and clamping to [qmin, qmax].
  constexpr int32_t qmin = std::numeric_limits<T>::min();
  constexpr int32_t qmax = std::numeric_limits<T>::max();
  float inv_scale = 1.0f / scale;
#ifndef _MSC_VER
  auto r = static_cast<int32_t>(Round(value * inv_scale));
  // builtin_add_overflow() returns true in case of overflow
  if (__builtin_add_overflow(zero_point, r, &r)) {
    // zero_point must be a non-negative value between qmin and qmax,
    // i.e. only overflow can happen.
    r = qmax;
  }
#else
  // MSVC lacks __builtin_add_overflow; the unchecked add relies on the
  // clamp below for in-range results.
  auto r = zero_point + static_cast<int32_t>(Round(value * inv_scale));
#endif
  r = std::max(r, qmin);
  r = std::min(r, qmax);
  return static_cast<T>(r);
}
164
165 template <typename T, int precision>
quantize_vec(double scale,int64_t zero_point,const float * src,T * dst,size_t count)166 void quantize_vec(
167 double scale,
168 int64_t zero_point,
169 const float* src,
170 T* dst,
171 size_t count) {
172 checkZeroPoint<typename T::underlying>("quantize_vec", zero_point);
173 for (const auto i : c10::irange(count)) {
174 dst[i] = quantize_val<T>(scale, zero_point, src[i]);
175 }
176 }
177
178 template uint8_t quantize_val_arm<uint8_t>(
179 const float scale,
180 const int32_t zero_point,
181 const float value);
182 template int8_t quantize_val_arm<int8_t>(
183 const float scale,
184 const int32_t zero_point,
185 const float value);
186 template <typename T>
dequantize_val(double scale,int64_t zero_point,T value)187 TORCH_API float dequantize_val(double scale, int64_t zero_point, T value) {
188 return static_cast<float>(scale) * (value.val_ - static_cast<int32_t>(zero_point));
189 }
190 #endif // USE_FBGEMM
191
192 /*
193 * Quantize value based on the following equation
194 * Xq = Round(Xf * inv_scale + zero_point)
195 * where zero_point is in float.
196 *
197 * Note: For the case of embedding quantization we will set zero_point
198 * to (-Xmin/scale), where Xmin is the min value in input tensor row.
199 */
int quantize_val_float_qparams(float scale, float zero_point, float value, int qmin, int qmax) {
  // Quantize with a floating-point zero point:
  //   Xq = Round(Xf * inv_scale + zero_point), clamped to [qmin, qmax].
  // A zero scale would divide by zero, so fall back to inv_scale = 1.
  const float inv_scale = (scale == 0) ? 1.0f : 1.0f / scale;
  const int rounded = static_cast<int>(lrintf(value * inv_scale + zero_point));
  return std::max(qmin, std::min(rounded, qmax));
}
207
208 template <typename SRC_T, typename DST_T>
requantize_val(double src_scale,int64_t src_zero_point,double dst_scale,int64_t dst_zero_point,SRC_T src)209 DST_T requantize_val(
210 double src_scale,
211 int64_t src_zero_point,
212 double dst_scale,
213 int64_t dst_zero_point,
214 SRC_T src) {
215 const auto dq = dequantize_val<SRC_T>(src_scale, src_zero_point, src);
216 return quantize_val<DST_T>(dst_scale, dst_zero_point, dq);
217 }
218
template <typename DST_T>
DST_T requantize_from_int(double multiplier, int64_t zero_point, int64_t src) {
  // Scale the integer accumulator by `multiplier`, shift by the destination
  // zero point, then clamp into DST_T's representable range.
  const int64_t scaled =
      zero_point + lrintf(static_cast<float>(static_cast<double>(src) * multiplier));
  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  const int32_t lo = std::numeric_limits<typename DST_T::underlying>::min();
  const int32_t hi = std::numeric_limits<typename DST_T::underlying>::max();
  return static_cast<DST_T>(
      std::max<int64_t>(lo, std::min<int64_t>(scaled, hi)));
}
229
230 template TORCH_API qint8
231 quantize_val<qint8>(double scale, int64_t zero_point, float value);
232 template TORCH_API quint8
233 quantize_val<quint8>(double scale, int64_t zero_point, float value);
234 template TORCH_API qint32
235 quantize_val<qint32>(double scale, int64_t zero_point, float value);
236 template TORCH_API void quantize_vec<c10::qint8>(
237 double scale,
238 int64_t zero_point,
239 const float* src,
240 c10::qint8* dst,
241 size_t count);
242 template TORCH_API void quantize_vec<c10::quint8>(
243 double scale,
244 int64_t zero_point,
245 const float* src,
246 c10::quint8* dst,
247 size_t count);
248 template TORCH_API void quantize_vec<c10::qint32, 32>(
249 double scale,
250 int64_t zero_point,
251 const float* src,
252 c10::qint32* dst,
253 size_t count);
254
255 template TORCH_API float dequantize_val<qint8>(
256 double scale,
257 int64_t zero_point,
258 qint8 value);
259 template TORCH_API float dequantize_val<quint8>(
260 double scale,
261 int64_t zero_point,
262 quint8 value);
263 template TORCH_API float dequantize_val<qint32>(
264 double scale,
265 int64_t zero_point,
266 qint32 value);
267
268 template TORCH_API qint8
269 requantize_val<qint8, qint8>(double, int64_t, double, int64_t, qint8);
270 template TORCH_API quint8
271 requantize_val<qint8, quint8>(double, int64_t, double, int64_t, qint8);
272 template TORCH_API qint32
273 requantize_val<qint8, qint32>(double, int64_t, double, int64_t, qint8);
274 template TORCH_API qint8
275 requantize_val<quint8, qint8>(double, int64_t, double, int64_t, quint8);
276 template TORCH_API quint8
277 requantize_val<quint8, quint8>(double, int64_t, double, int64_t, quint8);
278 template TORCH_API qint32
279 requantize_val<quint8, qint32>(double, int64_t, double, int64_t, quint8);
280 template TORCH_API qint8
281 requantize_val<qint32, qint8>(double, int64_t, double, int64_t, qint32);
282 template TORCH_API quint8
283 requantize_val<qint32, quint8>(double, int64_t, double, int64_t, qint32);
284 template TORCH_API qint32
285 requantize_val<qint32, qint32>(double, int64_t, double, int64_t, qint32);
286
287 template TORCH_API qint8 requantize_from_int<qint8>(double, int64_t, int64_t);
288 template TORCH_API quint8
289 requantize_from_int<quint8>(double, int64_t, int64_t);
290 template TORCH_API qint32
291 requantize_from_int<qint32>(double, int64_t, int64_t);
292
293 } // namespace at::native
294